Commit 3b91262e authored by Vijay Korthikanti, committed by Jared Casper

pipeline code simplification

parent 2f3a2d68
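The diff below removes the per-stage model classes (the FirstStage / IntermediateStage / LastStage variants) in favor of a single class per model that takes pre_process and post_process flags, and it routes inter-stage activations through a new set_input_tensor() hook instead of stage-specific forward() signatures. The following is a minimal, self-contained sketch of that pattern; the ToyStageModel class and its sizes are illustrative only and are not Megatron code.

import torch


class ToyStageModel(torch.nn.Module):
    """Toy single-class pipeline-stage model (illustrative, not Megatron code)."""

    def __init__(self, hidden_size, vocab_size, pre_process=True, post_process=True):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        if self.pre_process:
            # Only the first pipeline stage owns the embedding.
            self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.block = torch.nn.Linear(hidden_size, hidden_size)
        if self.post_process:
            # Only the last pipeline stage owns the output head.
            self.head = torch.nn.Linear(hidden_size, vocab_size)

    def set_input_tensor(self, input_tensor):
        # The pipeline schedule hands the activation received from the previous
        # stage to the model through this hook; the first stage ignores it.
        self.input_tensor = input_tensor

    def forward(self, input_ids):
        if self.pre_process:
            hidden = self.embedding(input_ids)
        else:
            hidden = self.input_tensor
        hidden = self.block(hidden)
        if self.post_process:
            return self.head(hidden)
        return hidden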
@@ -16,15 +16,7 @@
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
 from .distributed import *
-from .bert_model import (BertModel,
-                         BertModelFirstStage,
-                         BertModelIntermediateStage,
-                         BertModelLastStage)
-from .gpt_model import (GPTModel,
-                        GPTModelFirstStage,
-                        GPTModelIntermediateStage,
-                        GPTModelLastStage)
+from .bert_model import BertModel
+from .gpt_model import GPTModel
 from .language_model import get_language_model
 from .module import Float16Module

@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output,
     return lm_loss, binary_logits
-class BertModelBase(MegatronModule):
+class BertModel(MegatronModule):
     """Bert Language model."""
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=2,
+                 add_binary_head=True,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(BertModel, self).__init__()
         args = get_args()
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule):
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)
         self.initialize_word_embeddings(init_method_normal)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head = BertLMHead(
                 self.word_embeddings_weight().size(0),
                 args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
@@ -156,26 +164,29 @@ class BertModelBase(MegatronModule):
                 init_method)
             self._binary_head_key = 'binary_head'
+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
     def forward(self, bert_model_input, attention_mask,
                 tokentype_ids=None, lm_labels=None):
         extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = bert_model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [bert_model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        input_ids = bert_model_input
+        position_ids = bert_position_ids(input_ids)
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+        if self.post_process and self.add_binary_head:
             lm_output, pooled_output = lm_output
         else:
             pooled_output = None
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(lm_output, pooled_output,
                                                   self.lm_head, self.binary_head,
                                                   lm_labels,
@@ -194,15 +205,15 @@ class BertModelBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_
@@ -212,74 +223,13 @@ class BertModelBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
-class BertModel(BertModelBase):
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None, lm_labels=None):
-        return super(BertModel, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            lm_labels=lm_labels)
-class BertModelFirstStage(BertModelBase):
-    def __init__(self, num_tokentypes=2):
-        super(BertModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(BertModelFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-class BertModelIntermediateStage(BertModelBase):
-    def __init__(self, num_tokentypes=2):
-        super(BertModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, hidden_state, attention_mask):
-        return super(BertModelIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-class BertModelLastStage(BertModelBase):
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-    def forward(self, hidden_state, attention_mask,
-                lm_labels=None):
-        return super(BertModelLastStage, self).forward(
-            hidden_state,
-            attention_mask,
-            lm_labels=lm_labels)

@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
-class ClassificationBase(MegatronModule):
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationBase, self).__init__(share_word_embeddings=False)
+class Classification(MegatronModule):
+    def __init__(self,
+                 num_classes,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(Classification, self).__init__(share_word_embeddings=False)
         args = get_args()
         self.num_classes = num_classes
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)
         self.language_model, self._language_model_key = get_language_model(
@@ -43,31 +49,35 @@ class ClassificationBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                          args.num_layers))
+                                                          args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)
         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.classification_head = get_linear_layer(args.hidden_size,
                                                         self.num_classes,
                                                         init_method)
             self._classification_head_key = 'classification_head'
+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
     def forward(self, model_input, attention_mask, tokentype_ids=None):
         extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        position_ids = bert_position_ids(input_ids)
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+        if self.post_process:
             _, pooled_output = lm_output
             classification_output = self.classification_dropout(pooled_output)
             classification_logits = self.classification_head(classification_output)
@@ -87,7 +97,7 @@ class ClassificationBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._classification_head_key] \
                 = self.classification_head.state_dict(
                     destination, prefix, keep_vars)
@@ -98,7 +108,7 @@ class ClassificationBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._classification_head_key in state_dict:
                 self.classification_head.load_state_dict(
                     state_dict[self._classification_head_key], strict=strict)
@@ -106,55 +116,3 @@ class ClassificationBase(MegatronModule):
                 print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._classification_head_key))
-class Classification(ClassificationBase):
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(Classification, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-class ClassificationFirstStage(ClassificationBase):
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationFirstStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(ClassificationFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-class ClassificationIntermediateStage(ClassificationBase):
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationIntermediateStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-class ClassificationLastStage(ClassificationBase):
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationLastStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationLastStage, self).forward(
-            hidden_state,
-            attention_mask)

@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         return loss
-class GPTModelBase(MegatronModule):
+class GPTModel(MegatronModule):
     """GPT-2 Language model."""
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(GPTModel, self).__init__()
         args = get_args()
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.language_model, self._language_model_key = get_language_model(
@@ -73,24 +79,27 @@ class GPTModelBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                          args.num_layers))
+                                                          args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)
         self.initialize_word_embeddings(init_method_normal)
-    def forward(self, gpt_model_input, attention_mask, labels=None,
+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
-        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt_model_input
-            args = [input_ids, position_ids, attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [gpt_model_input, attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            attention_mask,
+            layer_past=layer_past,
+            get_key_value=get_key_value)
+        if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
@@ -109,7 +118,7 @@ class GPTModelBase(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_
@@ -118,79 +127,9 @@ class GPTModelBase(MegatronModule):
         """Customized load."""
         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)
-class GPTModel(GPTModelBase):
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            labels=labels,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
-class GPTModelFirstStage(GPTModelBase):
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPTModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-class GPTModelIntermediateStage(GPTModelBase):
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, hidden_state, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(GPTModelIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-class GPTModelLastStage(GPTModelBase):
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-    def forward(self, hidden_state, attention_mask, labels=None,
-                layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModelLastStage, self).forward(
-            hidden_state,
-            attention_mask,
-            labels=labels,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)

@@ -46,7 +46,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
                        scaled_init_method=None, add_decoder=False,
-                       decoder_attn_mask_type=AttnMaskType.causal):
+                       decoder_attn_mask_type=AttnMaskType.causal,
+                       pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()
@@ -58,26 +59,17 @@ def get_language_model(num_tokentypes, add_pooler,
                                                        args.num_layers)
     # Language model.
-    args = [init_method, scaled_init_method, encoder_attn_mask_type]
-    kwargs = {}
-    cls = None
-    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModel
-        kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['add_decoder'] = add_decoder
-        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
-        kwargs['add_pooler'] = add_pooler
-    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelFirstStage
-        kwargs['num_tokentypes'] = num_tokentypes
-    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelLastStage
-        kwargs['add_pooler'] = add_pooler
-    else:
-        cls = TransformerLanguageModelIntermediateStage
-    # Language model.
-    language_model = cls(*args, **kwargs)
+    language_model = TransformerLanguageModel(
+        init_method,
+        scaled_init_method,
+        encoder_attn_mask_type,
+        num_tokentypes=num_tokentypes,
+        add_decoder=add_decoder,
+        decoder_attn_mask_type=decoder_attn_mask_type,
+        add_pooler=add_pooler,
+        pre_process=pre_process,
+        post_process=post_process
+    )
     # key used for checkpoints.
     language_model_key = 'language_model'
@@ -263,7 +255,7 @@ class Embedding(MegatronModule):
                       'checkpoint but could not find it', flush=True)
-class TransformerLanguageModelBase(MegatronModule):
+class TransformerLanguageModel(MegatronModule):
     """Transformer language model.
     Arguments:
@@ -283,10 +275,14 @@ class TransformerLanguageModelBase(MegatronModule):
                  num_tokentypes=0,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
-                 add_pooler=False):
-        super(TransformerLanguageModelBase, self).__init__()
+                 add_pooler=False,
+                 pre_process=True,
+                 post_process=True):
+        super(TransformerLanguageModel, self).__init__()
         args = get_args()
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
@@ -296,7 +292,7 @@ class TransformerLanguageModelBase(MegatronModule):
         self.add_pooler = add_pooler
         # Embeddings.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,
@@ -309,7 +305,10 @@ class TransformerLanguageModelBase(MegatronModule):
         self.encoder = ParallelTransformer(
             self.init_method,
             output_layer_init_method,
-            self_attn_mask_type=self.encoder_attn_mask_type)
+            self_attn_mask_type=self.encoder_attn_mask_type,
+            pre_process=self.pre_process,
+            post_process=self.post_process
+        )
         self._encoder_key = 'encoder'
         # Decoder
@@ -323,26 +322,28 @@ class TransformerLanguageModelBase(MegatronModule):
                 self_attn_mask_type=self.decoder_attn_mask_type)
             self._decoder_key = 'decoder'
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'
-    def forward(self, enc_language_model_input, enc_attn_mask,
-                dec_language_model_input=None, dec_attn_mask=None,
+    def set_input_tensor(self, input_tensor):
+        self.encoder.set_input_tensor(input_tensor)
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                 get_key_value=False, pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
         # Embeddings.
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = enc_language_model_input
-            embedding_output = self.embedding(input_ids, position_ids,
+        if self.pre_process:
+            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                               tokentype_ids=tokentype_ids)
             encoder_input = embedding_output
         else:
-            encoder_input = enc_language_model_input
+            encoder_input = None
         # encoder.
         if enc_hidden_states is None:
@@ -353,7 +354,7 @@ class TransformerLanguageModelBase(MegatronModule):
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self.add_pooler:
                 pooled_output = self.pooler(encoder_output,
                                             pooling_sequence_index)
@@ -362,13 +363,12 @@ class TransformerLanguageModelBase(MegatronModule):
         # output. For example, it is helpful to compute
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler and mpu.is_pipeline_last_stage():
+            if self.add_pooler and self.post_process:
                 return encoder_output, pooled_output
             else:
                 return encoder_output
         # Decoder Embedding
-        (dec_input_ids, dec_position_ids) = dec_language_model_input
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
         # decoder
@@ -379,7 +379,7 @@ class TransformerLanguageModelBase(MegatronModule):
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask)
-        if self.add_pooler and mpu.is_pipeline_last_stage():
+        if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
         else:
             return decoder_output, encoder_output
@@ -389,14 +389,14 @@ class TransformerLanguageModelBase(MegatronModule):
         """For easy load."""
         state_dict_ = {}
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
         state_dict_[self._encoder_key] \
             = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
                     = self.pooler.state_dict_for_save_checkpoint(
@@ -412,7 +412,7 @@ class TransformerLanguageModelBase(MegatronModule):
         """Customized load."""
         # Embedding.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             if self._embedding_key in state_dict:
                 state_dict_ = state_dict[self._embedding_key]
             else:
@@ -448,7 +448,7 @@ class TransformerLanguageModelBase(MegatronModule):
         self.encoder.load_state_dict(state_dict_, strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # pooler
             if self.add_pooler:
                 assert 'pooler' in state_dict, \
@@ -461,124 +461,3 @@ class TransformerLanguageModelBase(MegatronModule):
                     'could not find data for pooler in the checkpoint'
         self.decoder.load_state_dict(state_dict[self._decoder_key],
                                      strict=strict)
-class TransformerLanguageModel(TransformerLanguageModelBase):
-    """Transformer language model (see TransformerLanguageModelBase
-    for description of arguments).
-    """
-    def __init__(self,
-                 init_method,
-                 output_layer_init_method,
-                 encoder_attn_mask_type,
-                 num_tokentypes=0,
-                 decoder_attn_mask_type=AttnMaskType.causal,
-                 add_decoder=False,
-                 add_pooler=False):
-        super(TransformerLanguageModel, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type,
-            num_tokentypes=num_tokentypes,
-            add_decoder=add_decoder,
-            decoder_attn_mask_type=decoder_attn_mask_type,
-            add_pooler=add_pooler)
-    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
-                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
-                enc_hidden_states=None, output_enc_hidden=False):
-        return super(TransformerLanguageModel, self).forward(
-            (enc_input_ids, enc_position_ids),
-            enc_attn_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            output_enc_hidden=output_enc_hidden
-        )
-class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
-    """Transformer language model, first stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-    def __init__(self,
-                 init_method,
-                 output_layer_init_method,
-                 encoder_attn_mask_type,
-                 num_tokentypes=0):
-        super(TransformerLanguageModelFirstStage, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type,
-            num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value
-        )
-class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
-    """Transformer language model, intermediate stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-    def __init__(self,
-                 init_method,
-                 output_layer_init_method,
-                 encoder_attn_mask_type):
-        super(TransformerLanguageModelIntermediateStage, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type)
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelIntermediateStage, self).forward(
-            hidden_states,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value
-        )
-class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
-    """Transformer language model, final stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-    def __init__(self,
-                 init_method,
-                 output_layer_init_method,
-                 encoder_attn_mask_type,
-                 add_pooler=False):
-        super(TransformerLanguageModelLastStage, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type,
-            add_pooler=add_pooler)
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
-        return super(TransformerLanguageModelLastStage, self).forward(
-            hidden_states,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index,
-        )

@@ -28,13 +28,18 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
-class MultipleChoiceBase(MegatronModule):
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
+class MultipleChoice(MegatronModule):
+    def __init__(self,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(MultipleChoice, self).__init__(share_word_embeddings=False)
         args = get_args()
         init_method = init_method_normal(args.init_method_std)
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
@@ -42,15 +47,20 @@ class MultipleChoiceBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                          args.num_layers))
+                                                          args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)
         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                      init_method)
             self._multichoice_head_key = 'multichoice_head'
+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
     def forward(self, model_input, attention_mask, tokentype_ids=None):
         # [batch, choices, sequence] --> [batch * choices, sequence] -->
@@ -64,22 +74,21 @@ class MultipleChoiceBase(MegatronModule):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1))
         extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            # Do the same as attention_mask for input_ids, tokentype_ids
-            assert len(input_ids.shape) == 3
-            assert len(tokentype_ids.shape) == 3
-            input_ids = input_ids.view(-1, input_ids.size(-1))
-            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        # Do the same as attention_mask for input_ids, tokentype_ids
+        assert len(input_ids.shape) == 3
+        assert len(tokentype_ids.shape) == 3
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
+        position_ids = bert_position_ids(input_ids)
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+        if self.post_process:
             _, pooled_output = lm_output
             multichoice_output = self.multichoice_dropout(pooled_output)
             multichoice_logits = self.multichoice_head(multichoice_output)
@@ -99,7 +108,7 @@ class MultipleChoiceBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._multichoice_head_key] \
                 = self.multichoice_head.state_dict(
                     destination, prefix, keep_vars)
@@ -110,7 +119,7 @@ class MultipleChoiceBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._multichoice_head_key in state_dict:
                 self.multichoice_head.load_state_dict(
                     state_dict[self._multichoice_head_key], strict=strict)
@@ -119,53 +128,3 @@ class MultipleChoiceBase(MegatronModule):
                                 'initializing to random'.format(
                                     self._multichoice_head_key))
-class MultipleChoice(MultipleChoiceBase):
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoice, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(MultipleChoice, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-class MultipleChoiceFirstStage(MultipleChoiceBase):
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(MultipleChoiceFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-class MultipleChoiceIntermediateStage(MultipleChoiceBase):
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, hidden_state, attention_mask):
-        return super(MultipleChoiceIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-class MultipleChoiceLastStage(MultipleChoiceBase):
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceLastStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-    def forward(self, hidden_state, attention_mask):
-        return super(MultipleChoiceLastStage, self).forward(
-            hidden_state,
-            attention_mask)

@@ -532,12 +532,16 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 pre_process=True, post_process=True):
         super(ParallelTransformer, self).__init__()
         args = get_args()
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.input_tensor = None
         # Store activation checkpoiting flag.
         self.checkpoint_activations = args.checkpoint_activations
@@ -580,7 +584,7 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
@@ -615,6 +619,9 @@ class ParallelTransformer(MegatronModule):
         return hidden_states
+    def set_input_tensor(self, input_tensor):
+        self.input_tensor = input_tensor
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
@@ -628,7 +635,7 @@ class ParallelTransformer(MegatronModule):
                 'get_key_value does not work with ' \
                 'activation checkpointing'
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
             # If the input flag for fp32 residual connection is set, convert for float.
             if self.fp32_residual_connection:
@@ -636,10 +643,12 @@ class ParallelTransformer(MegatronModule):
             # Otherwise, leave it as is.
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
+        else:
+            hidden_states = self.input_tensor
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
@@ -664,7 +673,7 @@ class ParallelTransformer(MegatronModule):
                     presents.append(present)
         # Final layer norm.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)

@@ -34,8 +34,10 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     timers = get_timers()
     timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    model.module.module.set_input_tensor(input_tensor)
+    output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
+        output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
         output_tensor = loss / get_num_microbatches()
         losses_reduced.append(loss_reduced)

@@ -61,10 +61,10 @@ def print_datetime(string):
     print_rank_0('[' + string + '] datetime: {} '.format(time_str))
 def pretrain(train_valid_test_dataset_provider,
              model_provider,
              forward_step_func,
              extra_args_provider=None,
              args_defaults={}):
     """Main training program.
@@ -196,7 +196,25 @@ def get_model(model_provider_func):
     args = get_args()
     # Build model on cpu.
-    model = model_provider_func()
+    pre_process = mpu.is_pipeline_first_stage()
+    post_process = mpu.is_pipeline_last_stage()
+    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
+       args.virtual_pipeline_model_parallel_size is not None:
+        model = []
+        for i in range(args.virtual_pipeline_model_parallel_size):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            m = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process
+            )
+            model.append(m)
+    else:
+        model = model_provider_func(
+            pre_process=pre_process,
+            post_process=post_process
+        )
     if not isinstance(model, list):
         model = [model]
@@ -651,16 +669,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             if not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          lr_scheduler)
             print_datetime('exiting program after {} minutes'.format(train_time))
             sys.exit()
         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
             if not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          lr_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()

@@ -17,56 +17,30 @@
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import (BertModel,
-                            BertModelFirstStage,
-                            BertModelIntermediateStage,
-                            BertModelLastStage)
+from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""
     print_rank_0('building BERT model ...')
     args = get_args()
     num_tokentypes = 2 if args.bert_binary_head else 0
-    def model_provider_pipelined():
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = BertModelFirstStage(
-                num_tokentypes=num_tokentypes)
-        elif mpu.is_pipeline_last_stage():
-            model = BertModelLastStage(
-                num_tokentypes=num_tokentypes,
-                add_binary_head=args.bert_binary_head,
-                parallel_output=True)
-        else:
-            model = BertModelIntermediateStage(
-                num_tokentypes=num_tokentypes)
-        return model
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
-                model.append(model_provider_pipelined())
-        else:
-            model = model_provider_pipelined()
-    else:
-        model = BertModel(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=args.bert_binary_head,
-            parallel_output=True)
+    model = BertModel(
+        num_tokentypes=num_tokentypes,
+        add_binary_head=args.bert_binary_head,
+        parallel_output=True,
+        pre_process=pre_process,
+        post_process=post_process)
     return model
@@ -96,7 +70,33 @@ def get_batch(data_iterator):
     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(loss_mask, sentence_order, output_tensor):
+    lm_loss_, sop_logits = output_tensor
+    lm_loss_ = lm_loss_.float()
+    loss_mask = loss_mask.float()
+    lm_loss = torch.sum(
+        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+    if sop_logits is not None:
+        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                                   sentence_order.view(-1),
+                                   ignore_index=-1)
+        sop_loss = sop_loss.float()
+        loss = lm_loss + sop_loss
+        averaged_losses = average_losses_across_data_parallel_group(
+            [lm_loss, sop_loss])
+        return loss, {'lm loss': averaged_losses[0],
+                      'sop loss': averaged_losses[1]}
+    else:
+        loss = lm_loss
+        averaged_losses = average_losses_across_data_parallel_group(
+            [lm_loss])
+        return loss, {'lm loss': averaged_losses[0]}
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -111,46 +111,10 @@ def forward_step(data_iterator, model, input_tensor):
         types = None
     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, padding_mask, tokentype_ids=types,
-                                  lm_labels=lm_labels)
-        else:
-            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, padding_mask)
-    if mpu.is_pipeline_last_stage():
-        lm_loss_, sop_logits = output_tensor
-        lm_loss_ = lm_loss_.float()
-        loss_mask = loss_mask.float()
-        lm_loss = torch.sum(
-            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-        if sop_logits is not None:
-            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
-                                       sentence_order.view(-1),
-                                       ignore_index=-1)
-            sop_loss = sop_loss.float()
-            loss = lm_loss + sop_loss
-            averaged_losses = average_losses_across_data_parallel_group(
-                [lm_loss, sop_loss])
-            return loss, {'lm loss': averaged_losses[0],
-                          'sop loss': averaged_losses[1]}
-        else:
-            loss = lm_loss
-            averaged_losses = average_losses_across_data_parallel_group(
-                [lm_loss])
-            return loss, {'lm loss': averaged_losses[0]}
-    return output_tensor
+    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
+                          lm_labels=lm_labels)
+    return output_tensor, partial(loss_func, loss_mask, sentence_order)
 def train_valid_test_datasets_provider(train_val_test_num_samples):

@@ -16,50 +16,28 @@
 """Pretrain GPT"""
 import torch
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import (GPTModel,
-                            GPTModelFirstStage,
-                            GPTModelIntermediateStage,
-                            GPTModelLastStage)
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""
     print_rank_0('building GPT model ...')
-    def model_provider_pipelined():
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = GPTModelFirstStage(num_tokentypes=0)
-        elif mpu.is_pipeline_last_stage():
-            model = GPTModelLastStage(
-                num_tokentypes=0, parallel_output=True)
-        else:
-            model = GPTModelIntermediateStage(
-                num_tokentypes=0)
-        return model
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
-                model.append(model_provider_pipelined())
-        else:
-            model = model_provider_pipelined()
-    else:
-        model = GPTModel(num_tokentypes=0, parallel_output=True)
+    model = GPTModel(
+        num_tokentypes=0,
+        parallel_output=True,
+        pre_process=pre_process,
+        post_process=post_process
+    )
     return model
@@ -94,8 +72,18 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, attention_mask, position_ids
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(loss_mask, output_tensor):
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+    return loss, {'lm loss': averaged_loss[0]}
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -106,31 +94,10 @@ def forward_step(data_iterator, model, input_tensor):
         data_iterator)
     timers('batch-generator').stop()
-    # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  labels=labels)
-        else:
-            output_tensor = model(tokens, position_ids, attention_mask)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask, labels=labels)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
-    if mpu.is_pipeline_last_stage():
-        losses = output_tensor.float()
-        loss_mask = loss_mask.view(-1).float()
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-        return loss, {'lm loss': averaged_loss[0]}
-    return output_tensor
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          labels=labels)
+    return output_tensor, partial(loss_func, loss_mask)
 def train_valid_test_datasets_provider(train_val_test_num_samples):
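To recap the training-side contract introduced above: model providers now accept pre_process and post_process, forward_step returns the output tensor together with a deferred loss function, and the schedule feeds the previous stage's activation in via set_input_tensor() and applies the loss function only on the last stage. Below is a hedged usage sketch that reuses the ToyStageModel from the sketch near the top of this page; the names and sizes here are illustrative placeholders, not Megatron's actual code.

from functools import partial

import torch


def model_provider(pre_process=True, post_process=True):
    # The provider is told which pipeline role it plays instead of choosing a
    # First/Intermediate/Last class itself (sizes are arbitrary placeholders).
    return ToyStageModel(hidden_size=16, vocab_size=100,
                         pre_process=pre_process, post_process=post_process)


def loss_func(loss_mask, output_tensor):
    # Mirrors the shape of the loss functions added in the diff above:
    # it returns the loss plus a dict of values to log.
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss, {'lm loss': loss.detach()}


def forward_step(data_iterator, model):
    # Returns the raw output plus a partially applied loss function; the
    # pipeline schedule decides (on the last stage) whether to call it.
    tokens, loss_mask = next(data_iterator)
    output_tensor = model(tokens)
    return output_tensor, partial(loss_func, loss_mask)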