Commit e270f68a authored by Jared Casper

Addressed MR comments, mostly adding comments to code.

parent 64a83fb5
@@ -124,8 +124,8 @@ def post_language_model_processing(lm_output, pooled_output,
class BertModel(MegatronModule):
"""Bert Language model."""
- def __init__(self,
- num_tokentypes=2,
+ def __init__(self,
+ num_tokentypes=2,
add_binary_head=True,
parallel_output=True,
pre_process=True,
@@ -165,6 +165,7 @@ class BertModel(MegatronModule):
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
......
@@ -30,8 +30,8 @@ from .module import MegatronModule
class Classification(MegatronModule):
- def __init__(self,
- num_classes,
+ def __init__(self,
+ num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
@@ -62,6 +62,7 @@ class Classification(MegatronModule):
self._classification_head_key = 'classification_head'
def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None):
......
@@ -86,6 +86,7 @@ class GPTModel(MegatronModule):
self.initialize_word_embeddings(init_method_normal)
def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
......
@@ -329,6 +329,7 @@ class TransformerLanguageModel(MegatronModule):
self._pooler_key = 'pooler'
def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
......
@@ -59,6 +59,7 @@ class MultipleChoice(MegatronModule):
self._multichoice_head_key = 'multichoice_head'
def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None):
......
@@ -576,11 +576,12 @@ class ParallelTransformer(MegatronModule):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
- args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+ args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
+ # Each stage gets a contiguous set of layers.
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
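To make the interleaved assignment concrete, here is a small self-contained sketch of the same offset formula (not the Megatron code itself), under an assumed toy configuration of 8 total layers, 2 pipeline stages, and virtual_pipeline_model_parallel_size = 2; `layers_per_chunk` stands in for `self.num_layers` in the diff. It reproduces the [0, 1] [4, 5] / [2, 3] [6, 7] split from the comment above.

```python
# Toy configuration (assumed for illustration, not taken from the commit).
num_layers_total = 8
pipeline_size = 2
virtual_size = 2
layers_per_chunk = num_layers_total // virtual_size // pipeline_size  # == 2

for pipeline_rank in range(pipeline_size):
    chunks = []
    for virtual_rank in range(virtual_size):
        # Same arithmetic as the diff: virtual rank picks the "block",
        # pipeline rank picks the chunk inside that block.
        offset = virtual_rank * (num_layers_total // virtual_size) \
            + pipeline_rank * layers_per_chunk
        chunks.append([offset + i for i in range(layers_per_chunk)])
    print(f"Stage {pipeline_rank}: {chunks}")

# Prints:
# Stage 0: [[0, 1], [4, 5]]
# Stage 1: [[2, 3], [6, 7]]
```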
@@ -620,6 +621,13 @@ class ParallelTransformer(MegatronModule):
return hidden_states
def set_input_tensor(self, input_tensor):
+ """Set input tensor to be used instead of forward()'s input.
+
+ When doing pipeline parallelism the input from the previous
+ stage comes from communication, not from the input, so the
+ model's forward_step_func won't have it. This function is thus
+ used by internal code to bypass the input provided by the
+ forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask, layer_past=None,
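The docstring above is the heart of the change: on pipeline stages other than the first, the activation arrives via communication, so the schedule injects it with set_input_tensor() rather than through forward()'s arguments. Below is a minimal, hypothetical sketch of that pattern (toy module and variable names, not Megatron's actual schedule or communication code):

```python
import torch
import torch.nn as nn

class TinyStage(nn.Module):
    """Toy stand-in for a pipeline stage: forward() ignores its argument
    whenever an input tensor has been injected from the outside."""

    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Called by a (hypothetical) pipeline schedule instead of passing
        # the received activation through forward()'s arguments.
        self.input_tensor = input_tensor

    def forward(self, hidden_states):
        if self.input_tensor is not None:
            # On non-first stages the real input came from communication,
            # not from the data iterator, so bypass the argument.
            hidden_states = self.input_tensor
        return self.linear(hidden_states)

# Two "stages" on one process, just to show the data flow.
stage0, stage1 = TinyStage(4), TinyStage(4)
batch = torch.randn(2, 4)

out0 = stage0(batch)            # first stage consumes the real batch
stage1.set_input_tensor(out0)   # stands in for recv() from stage 0
out1 = stage1(batch)            # the `batch` argument is ignored here
print(out1.shape)               # torch.Size([2, 4])
```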
@@ -644,6 +652,7 @@ class ParallelTransformer(MegatronModule):
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
+ # See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None:
......
@@ -204,11 +204,11 @@ def get_model(model_provider_func):
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
- m = model_provider_func(
+ this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
- model.append(m)
+ model.append(this_model)
else:
model = model_provider_func(
pre_process=pre_process,
@@ -249,7 +249,7 @@ def get_model(model_provider_func):
process_group=mpu.get_data_parallel_group())
for model_module in model]
return model
if args.DDP_impl == 'local':
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
......
@@ -89,6 +89,11 @@ def calculate_correct_answers(name, model, dataloader,
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
+ # If our dataset has a sample_multiplier attribute, that means
+ # each "sample" from the dataset actually has multiple samples
+ # that will collapse into the batch dimension (for example in
+ # the RACE dataset, which has several options), so we need to
+ # account for that when setting the micro batch size.
sample_multiplier = ds.sample_multiplier
else:
sample_multiplier = 1
......
@@ -137,6 +137,11 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size
if hasattr(train_dataset, 'sample_multiplier'):
+ # If our dataset has a sample_multiplier attribute, that means
+ # each "sample" from the dataset actually has multiple samples
+ # that will collapse into the batch dimension (for example in
+ # the RACE dataset, which has several options), so we need to
+ # account for that when setting the micro batch size.
args.micro_batch_size *= train_dataset.sample_multiplier
args.global_batch_size *= train_dataset.sample_multiplier
......
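The comment added in the two hunks above describes why the batch sizes get rescaled: each dataset item fans out into sample_multiplier rows along the batch dimension. A rough sketch of that arithmetic follows, with illustrative numbers only (the NUM_CHOICES = 4 and batch sizes here are assumptions, not values from the diff):

```python
# Illustrative values; the real ones come from the training args and the
# dataset, not from this sketch.
NUM_CHOICES = 4               # e.g. a RACE-style multiple-choice setup
orig_micro_batch_size = 8
orig_global_batch_size = 64

sample_multiplier = NUM_CHOICES

# Mirrors the scaling in _build_train_valid_dataloaders: each "sample"
# yields sample_multiplier rows once the choices are flattened into the
# batch dimension, so the effective batch sizes grow by the same factor.
micro_batch_size = orig_micro_batch_size * sample_multiplier    # 32 rows
global_batch_size = orig_global_batch_size * sample_multiplier  # 256 rows

print(micro_batch_size, global_batch_size)
```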
@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
+ # This indicates that each "sample" has multiple samples that
+ # will collapse into the batch dimension.
self.sample_multiplier = NUM_CHOICES
def __len__(self):
......