"tests/vscode:/vscode.git/clone" did not exist on "f1f53d0fea684b698aa4dbdae4812a051dc07da0"
Commit e270f68a authored by Jared Casper

Addressed MR comments, mostly adding comments to code.

parent 64a83fb5
@@ -124,8 +124,8 @@ def post_language_model_processing(lm_output, pooled_output,
 class BertModel(MegatronModule):
     """Bert Language model."""

     def __init__(self,
                  num_tokentypes=2,
                  add_binary_head=True,
                  parallel_output=True,
                  pre_process=True,
@@ -165,6 +165,7 @@ class BertModel(MegatronModule):
         self._binary_head_key = 'binary_head'

     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, bert_model_input, attention_mask,
...
@@ -30,8 +30,8 @@ from .module import MegatronModule
 class Classification(MegatronModule):

     def __init__(self,
                  num_classes,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):
@@ -62,6 +62,7 @@ class Classification(MegatronModule):
         self._classification_head_key = 'classification_head'

     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, model_input, attention_mask, tokentype_ids=None):
...
@@ -86,6 +86,7 @@ class GPTModel(MegatronModule):
         self.initialize_word_embeddings(init_method_normal)

     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
...
@@ -329,6 +329,7 @@ class TransformerLanguageModel(MegatronModule):
             self._pooler_key = 'pooler'

     def set_input_tensor(self, input_tensor):
+        """ See megatron.model.transformer.set_input_tensor()"""
         self.encoder.set_input_tensor(input_tensor)

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
...
@@ -59,6 +59,7 @@ class MultipleChoice(MegatronModule):
         self._multichoice_head_key = 'multichoice_head'

     def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, model_input, attention_mask, tokentype_ids=None):
...
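All of the one-line docstrings added above point at the same mechanism: each top-level model simply forwards set_input_tensor to its language model, which forwards it to its encoder, where ParallelTransformer finally stores the tensor (see the transformer.py hunks below). A minimal sketch of that delegation chain, using stand-in Toy* classes rather than the real Megatron modules:

# Sketch only: stand-in classes showing how set_input_tensor is delegated
# from the top-level model down to the transformer block.
class ToyParallelTransformer:
    def __init__(self):
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Stored here; forward() will prefer it over its own hidden_states arg.
        self.input_tensor = input_tensor


class ToyLanguageModel:
    def __init__(self):
        self.encoder = ToyParallelTransformer()

    def set_input_tensor(self, input_tensor):
        self.encoder.set_input_tensor(input_tensor)


class ToyBertModel:
    def __init__(self):
        self.language_model = ToyLanguageModel()

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)


model = ToyBertModel()
model.set_input_tensor("activation from previous stage")
assert model.language_model.encoder.input_tensor == "activation from previous stage"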
@@ -576,11 +576,12 @@ class ParallelTransformer(MegatronModule):
             # Stage 0: [0, 1]  [4, 5]
             # Stage 1: [2, 3]  [6, 7]
             offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
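The example in the comment above can be checked with plain arithmetic: with 8 total layers, 2 pipeline stages and a virtual pipeline size of 2, each model chunk owns 2 layers and the offset formula reproduces the interleaved assignment. A small self-contained check (ordinary Python, with the mpu rank calls replaced by loop variables):

# Reproduces the "Stage 0: [0, 1] [4, 5] / Stage 1: [2, 3] [6, 7]" example.
num_layers = 8        # args.num_layers
virtual_size = 2      # args.virtual_pipeline_model_parallel_size
pipeline_size = 2     # number of pipeline stages
layers_per_chunk = num_layers // (virtual_size * pipeline_size)  # self.num_layers == 2

for pipeline_rank in range(pipeline_size):
    chunks = []
    for virtual_rank in range(virtual_size):
        offset = virtual_rank * (num_layers // virtual_size) + \
            pipeline_rank * layers_per_chunk
        # build_layer(i + 1 + offset) numbers layers from 1; the comment is 0-based.
        chunks.append([offset + i for i in range(layers_per_chunk)])
    print('Stage {}: {}'.format(pipeline_rank, chunks))
# Stage 0: [[0, 1], [4, 5]]
# Stage 1: [[2, 3], [6, 7]]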
@@ -620,6 +621,13 @@ class ParallelTransformer(MegatronModule):
         return hidden_states

     def set_input_tensor(self, input_tensor):
+        """Set input tensor to be used instead of forward()'s input.
+
+        When doing pipeline parallelism the input from the previous
+        stage comes from communication, not from the input, so the
+        model's forward_step_func won't have it. This function is thus
+        used by internal code to bypass the input provided by the
+        forward_step_func"""
         self.input_tensor = input_tensor

     def forward(self, hidden_states, attention_mask, layer_past=None,
@@ -644,6 +652,7 @@ class ParallelTransformer(MegatronModule):
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
         else:
+            # See set_input_tensor()
             hidden_states = self.input_tensor

         if encoder_output is not None:
...
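To make the new docstring concrete, here is a rough sketch of how a non-first pipeline stage might use this hook; recv_from_prev_stage and forward_step_func are hypothetical stand-ins for Megatron's p2p communication and schedule code, not its actual API:

def run_step_on_later_stage(model, batch, recv_from_prev_stage, forward_step_func):
    # On any stage but the first, the "real" input is the activation received
    # from the previous pipeline stage, not the data in `batch`.
    input_tensor = recv_from_prev_stage()   # comes from communication
    model.set_input_tensor(input_tensor)    # stash it for forward() to pick up
    # forward_step_func builds the usual forward call from `batch`; inside
    # ParallelTransformer.forward, self.input_tensor then takes the place of
    # the hidden_states argument (the "See set_input_tensor()" branch above).
    return forward_step_func(batch, model)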
@@ -204,11 +204,11 @@ def get_model(model_provider_func):
         model = []
         for i in range(args.virtual_pipeline_model_parallel_size):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
-            m = model_provider_func(
+            this_model = model_provider_func(
                 pre_process=pre_process,
                 post_process=post_process
             )
-            model.append(m)
+            model.append(this_model)
     else:
         model = model_provider_func(
             pre_process=pre_process,
@@ -249,7 +249,7 @@ def get_model(model_provider_func):
                          process_group=mpu.get_data_parallel_group())
                  for model_module in model]
         return model

     if args.DDP_impl == 'local':
         model = [LocalDDP(model_module,
                           args.accumulate_allreduce_grads_in_fp32,
...
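For context, the loop touched here builds one model chunk per virtual pipeline rank when interleaved pipeline parallelism is enabled; the rename from m to this_model only makes the loop variable more descriptive. A stripped-down sketch of the pattern, assuming the model_provider_func(pre_process, post_process) signature shown in the diff and a hypothetical set_virtual_rank callable standing in for mpu.set_virtual_pipeline_model_parallel_rank:

def build_model_chunks(model_provider_func, virtual_pipeline_size,
                       set_virtual_rank, pre_process=True, post_process=True):
    """Sketch: a list of model chunks with virtual pipelining, else one model."""
    if virtual_pipeline_size is not None:
        model = []
        for i in range(virtual_pipeline_size):
            set_virtual_rank(i)  # each chunk is built under its own virtual rank
            this_model = model_provider_func(pre_process=pre_process,
                                             post_process=post_process)
            model.append(this_model)
    else:
        model = model_provider_func(pre_process=pre_process,
                                    post_process=post_process)
    return model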
@@ -89,6 +89,11 @@ def calculate_correct_answers(name, model, dataloader,
     ds = dataloader.dataset
     if hasattr(ds, 'sample_multiplier'):
+        # If our dataset has a sample_multiplier attribute, that means
+        # each "sample" from the dataset actually has multiple samples
+        # that will collapse into the batch dimension (for example in
+        # the RACE dataset that has several options), so we need to
+        # account for that when setting the micro batch size.
         sample_multiplier = ds.sample_multiplier
     else:
         sample_multiplier = 1
...
@@ -137,6 +137,11 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     args.orig_micro_batch_size = args.micro_batch_size
     args.orig_global_batch_size = args.global_batch_size
     if hasattr(train_dataset, 'sample_multiplier'):
+        # If our dataset has a sample_multiplier attribute, that means
+        # each "sample" from the dataset actually has multiple samples
+        # that will collapse into the batch dimension (for example in
+        # the RACE dataset that has several options), so we need to
+        # account for that when setting the micro batch size.
         args.micro_batch_size *= train_dataset.sample_multiplier
         args.global_batch_size *= train_dataset.sample_multiplier
...
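Both of these hunks describe the same adjustment: when each dataset item packs sample_multiplier sub-samples into the batch dimension, the effective micro and global batch sizes grow by that factor, and the orig_* fields keep the configured values around for later restoration. A toy check of the arithmetic (the multiplier of 4 is an assumption matching RACE's four answer options):

micro_batch_size = 8      # configured number of dataset items per micro batch
sample_multiplier = 4     # e.g. RACE: one item per question, four options
# After scaling, each forward pass really sees 8 questions * 4 options = 32
# rows along the batch dimension.
effective_micro_batch_size = micro_batch_size * sample_multiplier
assert effective_micro_batch_size == 32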
@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
         print_rank_0('  >> total number of samples: {}'.format(
             len(self.samples)))

+        # This indicates that each "sample" has multiple samples that
+        # will collapse into the batch dimension.
         self.sample_multiplier = NUM_CHOICES

     def __len__(self):
...
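A rough illustration of what "collapse into the batch dimension" means here: each RACE item yields NUM_CHOICES rows of tokens (one per answer option), so a batch of B items becomes B * NUM_CHOICES rows once stacked and flattened. The shapes and layout below are illustrative, not the dataset's actual output format:

import torch

NUM_CHOICES = 4    # RACE offers four answer options per question
SEQ_LEN = 512
BATCH = 8          # dataset items (questions) per micro batch

# One dataset "sample": one row of token ids per answer option.
item = torch.randint(0, 30000, (NUM_CHOICES, SEQ_LEN))

# Default collation stacks items into [BATCH, NUM_CHOICES, SEQ_LEN] ...
batch = torch.stack([item.clone() for _ in range(BATCH)])

# ... and the options are folded into the batch dimension before the model
# sees them, which is why micro_batch_size is scaled by sample_multiplier.
flat = batch.view(BATCH * NUM_CHOICES, SEQ_LEN)
print(flat.shape)  # torch.Size([32, 512])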