Commit 46c74b4c authored by Deepak Narayanan's avatar Deepak Narayanan

Support for pipeline parallelism in T5 model

- Accumulate encoder hidden state gradient to handle skip connection
- Correctly compute the number of layers in encoder / decoder for T5 model
- Ensure embedding weights are initialized the same way across pipeline stages
- Synchronize embedding gradients across encoder and decoder for T5 model
- Support for checkpoint loading and saving
parent 6a680986
......@@ -80,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)
# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
......@@ -567,6 +573,9 @@ def _add_distributed_args(parser):
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
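
A minimal standalone sketch (not part of this commit) of how the new split-rank flag is meant to relate to the pipeline size; the parser below only recreates the two flags added above, and mirrors the check added in parse_args().

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pipeline-model-parallel-size', type=int, default=1)
parser.add_argument('--pipeline-model-parallel-split-rank', type=int, default=None)
args = parser.parse_args(['--pipeline-model-parallel-size', '4',
                          '--pipeline-model-parallel-split-rank', '2'])

# Stages [0, split_rank) hold the encoder and [split_rank, size) hold the
# decoder, so the split rank must lie strictly inside the pipeline.
if args.pipeline_model_parallel_size > 1 and \
        args.pipeline_model_parallel_split_rank is not None:
    assert args.pipeline_model_parallel_split_rank < \
        args.pipeline_model_parallel_size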
......
......@@ -193,7 +193,8 @@ def _initialize_distributed():
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
def _init_autoresume():
......
......@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
......@@ -15,6 +15,10 @@
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum):
encoder = 1
decoder = 2
......
......@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
scaled_init_method=None, add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True):
"""Build language model and return along with the key to save."""
......@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
......@@ -159,6 +161,13 @@ class Embedding(MegatronModule):
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.data.fill_(0)
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
......@@ -273,6 +282,7 @@ class TransformerLanguageModel(MegatronModule):
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
......@@ -286,10 +296,12 @@ class TransformerLanguageModel(MegatronModule):
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings.
if self.pre_process:
......@@ -302,25 +314,33 @@ class TransformerLanguageModel(MegatronModule):
self._embedding_key = 'embedding'
# Transformer.
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process
)
self._encoder_key = 'encoder'
# Decoder
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process
)
self._encoder_key = 'encoder'
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process)
self._decoder_key = 'decoder'
else:
self.decoder = None
if self.post_process:
# Pooler.
......@@ -330,7 +350,25 @@ class TransformerLanguageModel(MegatronModule):
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
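
A small self-contained sketch (plain Python, hypothetical helper name route_input) of the three cases handled by set_input_tensor above: an encoder stage receives a single activation, the first decoder stage receives only the encoder output, and later decoder stages receive both their own activation and the encoder output.

def route_input(add_encoder, add_decoder, input_tensor):
    if add_encoder:
        # Encoder-only stage, or a single stage holding both encoder and decoder.
        assert len(input_tensor) == 1
        return {'encoder_input': input_tensor[0], 'encoder_hidden_state': None}
    assert add_decoder
    if len(input_tensor) == 2:
        # Later decoder stage: decoder activation plus encoder output.
        return {'decoder_input': input_tensor[0],
                'encoder_hidden_state': input_tensor[1]}
    # First decoder stage: only the encoder output is received.
    return {'decoder_input': None, 'encoder_hidden_state': input_tensor[0]}

print(route_input(False, True, ['enc_out']))
print(route_input(False, True, ['dec_act', 'enc_out']))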
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
......@@ -338,20 +376,22 @@ class TransformerLanguageModel(MegatronModule):
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Embeddings.
# Encoder embedding.
if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids,
tokentype_ids=tokentype_ids)
encoder_input = embedding_output
encoder_input = self.embedding(enc_input_ids, enc_position_ids,
tokentype_ids=tokentype_ids)
else:
encoder_input = None
# encoder.
# Run encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if self.encoder is not None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
......@@ -369,11 +409,15 @@ class TransformerLanguageModel(MegatronModule):
else:
return encoder_output
# Decoder Embedding
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids,
dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder(decoder_input,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
......@@ -394,9 +438,10 @@ class TransformerLanguageModel(MegatronModule):
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_encoder:
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] \
......@@ -425,38 +470,39 @@ class TransformerLanguageModel(MegatronModule):
self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.post_process:
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
# Decoder.
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
......
......@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
mpu.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(ignore_virtual=True):
else:
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal):
......@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
......@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
mpu.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
......
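
A toy sketch (assuming three ranks in the embedding group: first stage, split-rank stage, last stage) of the zero-then-all-reduce trick used in initialize_word_embeddings above: ranks that do not own a real embedding contribute zeros, so the summing all-reduce leaves every rank holding the first stage's initialization.

import torch

torch.manual_seed(0)
first_stage = torch.randn(8, 4)    # randomly initialized word embedding
split_stage = torch.zeros(8, 4)    # zeroed via zero_parameters()
last_stage = torch.zeros(8, 4)     # zero-filled placeholder on the last stage

# all_reduce with the default SUM op computes this elementwise sum on each rank.
summed = first_stage + split_stage + last_stage
assert torch.equal(summed, first_stage)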
......@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):
class T5Model(MegatronModule):
"""T5 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True):
def __init__(self,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True,
add_encoder=True,
add_decoder=True):
super(T5Model, self).__init__()
args = get_args()
......@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
add_decoder=True,
add_encoder=add_encoder,
add_decoder=add_decoder,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.lm_head = T5LMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
parallel_output)
self._lm_head_key = 'lm_head'
self.initialize_word_embeddings(init_method_normal)
if self.post_process and self.add_decoder:
self.lm_head = T5LMHead(
self.word_embeddings_weight().size(0),
parallel_output)
self._lm_head_key = 'lm_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
......@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
tokentype_ids=tokentype_ids,
enc_hidden_states=enc_hidden_states)
decoder_output, encoder_output = lm_output
# Output.
lm_logits = self.lm_head(decoder_output,
self.language_model.embedding.word_embeddings.weight)
if self.post_process and self.add_decoder:
decoder_output, encoder_output = lm_output
# Output.
lm_logits = self.lm_head(decoder_output,
self.word_embeddings_weight())
if lm_labels is None:
return lm_logits, encoder_output
else:
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
if lm_labels is None:
return lm_logits
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss, encoder_output
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss
elif self.add_decoder and not self.add_encoder:
decoder_output, encoder_output = lm_output
return decoder_output
else:
encoder_output = lm_output
return encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process and self.add_decoder:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
......@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
if self.post_process and self.add_decoder:
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
# Load word embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
......@@ -21,7 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......@@ -548,9 +548,8 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_num_layers = args.checkpoint_num_layers
# Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
# Transformer layers.
def build_layer(layer_number):
......
......@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
......
......@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
......@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized():
......@@ -52,13 +56,19 @@ def is_unitialized():
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None):
virtual_pipeline_model_parallel_size_=None,
pipeline_model_parallel_split_rank_=None):
"""
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
......@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
......@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
......@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
pipeline_model_parallel_split_rank_ not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
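
A worked example (hypothetical pipeline group [g0, g4, g8, g12] with split rank 2) restating the membership rule above in simplified form: the embedding group now also contains the first decoder stage, which owns the decoder-side embedding.

ranks = ['g0', 'g4', 'g8', 'g12']          # one pipeline group, 4 stages
split_rank = 2                              # first decoder stage
embedding_ranks = [ranks[0], ranks[-1]]
if split_rank is not None and ranks[split_rank] not in embedding_ranks:
    embedding_ranks = [ranks[0], ranks[split_rank], ranks[-1]]
print(embedding_ranks)                      # ['g0', 'g8', 'g12']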
def model_parallel_is_initialized():
......@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
assert args.num_layers % num_ranks_in_encoder == 0, \
'num_layers must be divisible by number of ranks given to encoder'
assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers must be divisible by number of ranks given to decoder'
if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
else:
num_layers = args.num_layers
return num_layers
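
A worked example of the split arithmetic above (hypothetical sizes: 4 pipeline stages, split rank 2, num_layers 12, where num_layers is the layer count of each block): encoder and decoder layers are divided independently across their own stages.

pipeline_size, split_rank, num_layers = 4, 2, 12

num_ranks_in_encoder = split_rank                    # stages 0-1 run the encoder
num_ranks_in_decoder = pipeline_size - split_rank    # stages 2-3 run the decoder
assert num_layers % num_ranks_in_encoder == 0
assert num_layers % num_ranks_in_decoder == 0
print(num_layers // num_ranks_in_encoder)            # 6 encoder layers per encoder stage
print(num_layers // num_ranks_in_decoder)            # 6 decoder layers per decoder stage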
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
......@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
......
......@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
grad = param.grad.detach()
if grad_not_none:
grad = param.grad.detach()
if grad_not_none:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
......
......@@ -173,7 +173,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contihuous buffer
for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
......@@ -305,7 +305,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
if self.params_have_main_grad:
if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
......
......@@ -22,8 +22,8 @@ from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False, tensor_shape=None,
override_scatter_gather_tensors_in_pipeline=False,
tensor_shape,
use_ring_exchange=False,
dtype_=None):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
......@@ -37,16 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
tensor_shape: shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
tensor_shape: optional, use when the input sequence contains less
tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline: optional, this is used
when tensor_shape is
provided to overwide
scatter gather tensors
dtype_: optional, this is used when tensor_shape is provied and what
is the type of tensor_shape
dtype_: optional, this is used when the tensor that needs to be
communicated is different from args.params_dtype.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
......@@ -56,12 +53,15 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
if tensor_shape is None:
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
mpu.get_tensor_model_parallel_world_size()
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
tensor_chunk_shape = tensor_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
override_scatter_gather_tensors_in_pipeline = True
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
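
A worked example (hypothetical sizes) of the divisibility fallback above: the scatter/gather optimization is only kept when the flattened activation splits evenly across the tensor-model-parallel group.

from functools import reduce
import operator

tensor_shape = (512, 4, 1024)                       # (seq_len, micro_batch, hidden)
tp_world_size = 8
numel = reduce(operator.mul, tensor_shape, 1)       # 2,097,152 elements
if numel % tp_world_size == 0:
    tensor_chunk_shape = numel // tp_world_size     # 262,144-element flat chunk
else:
    tensor_chunk_shape = tensor_shape               # fall back to the full tensor
print(tensor_chunk_shape)                           # 262144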
......@@ -143,9 +143,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next
def recv_forward(tensor_shape=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None, timers=None):
def recv_forward(tensor_shape, dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
......@@ -159,15 +157,13 @@ def recv_forward(tensor_shape=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(timers=None):
def recv_backward(tensor_shape, timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
......@@ -178,15 +174,14 @@ def recv_backward(timers=None):
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
recv_next=True,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, timers=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None):
def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
......@@ -197,14 +192,13 @@ def send_forward(output_tensor, timers=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=dtype_)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None):
def send_backward(input_tensor_grad, tensor_shape, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
......@@ -213,12 +207,13 @@ def send_backward(input_tensor_grad, timers=None):
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False)
recv_next=False,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None):
def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
......@@ -229,13 +224,14 @@ def send_forward_recv_backward(output_tensor, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
recv_next=True,
tensor_shape=tensor_shape)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None):
def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
......@@ -246,13 +242,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None):
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False)
recv_next=False,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
......@@ -260,13 +257,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False)
recv_next=False,
tensor_shape=tensor_shape)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
......@@ -274,7 +272,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next)
recv_next=recv_next,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
......@@ -282,7 +281,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
recv_next, tensor_shape, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
......@@ -290,7 +289,8 @@ def send_forward_backward_recv_forward_backward(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next)
recv_next=recv_next,
tensor_shape=tensor_shape)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
......@@ -25,6 +25,8 @@ from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def get_forward_backward_func():
args = get_args()
......@@ -48,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
passed-in input_tensor is used.
Returns output tensor."""
args = get_args()
timers = get_timers()
timers('forward-compute').start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
......@@ -62,7 +71,12 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
return output_tensor
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
......@@ -73,24 +87,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
args = get_args()
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = None
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
input_tensor_grad = []
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
timers('backward-compute').stop()
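
A minimal torch sketch (not the commit's code) of why the extra add_ above is needed: the encoder hidden state feeds both this stage's decoder and later decoder stages, so its gradient arrives along two paths and must be accumulated.

import torch

enc_hidden = torch.ones(2, 3, requires_grad=True)
(enc_hidden * 2).sum().backward()                 # path 1: used by this stage's decoder
grad_local = enc_hidden.grad.clone()              # all 2.0

grad_from_next_stage = torch.full((2, 3), 0.5)    # path 2: output_tensor_grad[1]
total = grad_local.add_(grad_from_next_stage)     # mirrors input_tensor_grad[-1].add_(...)
print(total)                                      # 2.5 everywhere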
......@@ -153,6 +196,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
......@@ -231,7 +277,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(timers=timers))
p2p_communication.recv_forward(tensor_shape, timers=timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
......@@ -260,12 +306,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, timers=timers)
output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
......@@ -329,7 +378,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
tensor_shape=tensor_shape, timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
......@@ -343,7 +392,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers=timers))
p2p_communication.recv_backward(tensor_shape, timers=timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
......@@ -355,11 +404,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, timers=timers))
input_tensor_grad, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers))
return losses_reduced
def get_tensor_shapes(rank, model_type):
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
args = get_args()
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if mpu.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if mpu.is_pipeline_stage_before_split(rank+1):
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
else:
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
return tensor_shapes
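
A standalone restatement of the shape selection above with hypothetical sizes (seq_length 512, decoder_seq_length 128, micro_batch_size 4, hidden_size 1024, 4 stages, split rank 2), showing what crosses each stage boundary.

def shapes_sent_by(rank, split_rank=2, s=512, s_dec=128, b=4, h=1024):
    if rank < split_rank:                  # encoder stage
        if rank + 1 < split_rank:          # next stage is still an encoder stage
            return [(s, b, h)]             # pre-transpose activation
        return [(b, s, h)]                 # post-transpose encoder output at the split
    return [(s_dec, b, h), (b, s, h)]      # decoder activation + encoder skip connection

for rank in range(3):
    print(rank, '->', rank + 1, shapes_sent_by(rank))
# 0 -> 1 [(512, 4, 1024)]
# 1 -> 2 [(4, 512, 1024)]
# 2 -> 3 [(128, 4, 1024), (4, 512, 1024)]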
def recv_forward(tensor_shapes, timers):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape,
timers=timers))
return input_tensors
def recv_backward(tensor_shapes, timers):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
timers=timers))
return output_tensor_grads
def send_forward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
def send_backward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
......@@ -383,16 +528,23 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
model_type = unwrapped_model.model_type
rank = mpu.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
send_tensor_shapes = get_tensor_shapes(rank, model_type)
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers=timers)
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers=timers)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
......@@ -401,7 +553,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers=timers)
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
......@@ -410,12 +562,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers=timers)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor,
timers=timers)
send_forward_recv_backward(output_tensor,
send_tensor_shapes,
timers=timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
......@@ -423,7 +575,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers=timers)
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
......@@ -433,11 +585,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers=timers)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers=timers)
send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, timers=timers)
# Run cooldown backward passes.
if not forward_only:
......@@ -445,12 +597,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers=timers)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced
......@@ -38,6 +38,7 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
......@@ -61,6 +62,7 @@ def print_datetime(string):
def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
extra_args_provider=None,
args_defaults={}):
......@@ -77,6 +79,7 @@ def pretrain(train_valid_test_dataset_provider,
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
......@@ -109,7 +112,8 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
......@@ -189,13 +193,16 @@ def update_train_iters(args):
print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func):
def get_model(model_provider_func, model_type):
"""Build the model."""
args = get_args()
args.model_type = model_type
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
......@@ -206,14 +213,36 @@ def get_model(model_provider_func):
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.pipeline_model_parallel_split_rank is not None, \
"Split rank needs to be specified for model with both encoder and decoder"
rank = mpu.get_pipeline_model_parallel_rank()
split_rank = args.pipeline_model_parallel_split_rank
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (
rank == (world_size - 1))
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
if not isinstance(model, list):
model = [model]
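
A worked example (hypothetical setup: 4 pipeline stages, split rank 2) of the per-rank flags derived by the logic above for an encoder-and-decoder model.

world_size, split_rank = 4, 2
for rank in range(world_size):
    pre_process = rank == 0 or rank == split_rank
    post_process = (rank == split_rank - 1) or (rank == world_size - 1)
    add_encoder = rank < split_rank
    add_decoder = rank >= split_rank
    print(rank, pre_process, post_process, add_encoder, add_decoder)
# rank 0: input embedding + encoder layers        True  False True  False
# rank 1: encoder layers + final encoder output   False True  True  False
# rank 2: decoder embedding + decoder layers      True  False False True
# rank 3: decoder layers + LM head                False True  False True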
......@@ -304,11 +333,11 @@ def get_learning_rate_scheduler(optimizer):
return lr_scheduler
def setup_model_and_optimizer(model_provider_func):
def setup_model_and_optimizer(model_provider_func, model_type):
"""Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func)
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
......@@ -374,13 +403,14 @@ def train_step(forward_step_func, data_iterator,
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
......
......@@ -25,7 +25,7 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.model import BertModel, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......@@ -143,5 +143,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
......@@ -23,7 +23,7 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel
from megatron.model import GPTModel, ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
......@@ -121,5 +121,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
......@@ -28,6 +28,7 @@ from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......@@ -174,5 +175,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
......@@ -26,18 +26,22 @@ from megatron import (
print_rank_0
)
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model
from megatron.model import T5Model, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
def model_provider(pre_process=True, post_process=True,
add_encoder=True, add_decoder=True):
"""Build the model."""
assert pre_process and post_process, "T5 doesn't yet support pipelining"
print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0,
parallel_output=True)
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
return model
......@@ -70,9 +74,7 @@ def get_batch(data_iterator):
def loss_func(loss_mask, output_tensor):
lm_loss_, _ = output_tensor
lm_loss_ = lm_loss_.float()
lm_loss_ = output_tensor.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
......@@ -130,5 +132,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
......@@ -20,6 +20,7 @@ import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......@@ -92,6 +93,7 @@ if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)