Commit d439b417 authored by Vijay Korthikanti

update comment

parent 9da6e975
@@ -109,9 +109,8 @@ class MegatronModule(torch.nn.Module):
                 torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                              group=mpu.get_embedding_group())
-            # All-reduce other embeddings as well as necessary. The last stage
-            # does not have these other embeddings, so just create placeholder
-            # tensors of the right shape with all zeros.
+            # Ensure that encoder(first stage) and decoder(split stage) position
+            # embeddings have the same initial parameter values
             # NOTE: We don't currently support T5 with the interleaved schedule.
             if mpu.is_rank_in_position_embedding_group() and \
                     args.pipeline_model_parallel_split_rank is not None:
...
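For context, the sketch below illustrates the synchronization pattern the new comment describes, not the repository's actual code: each rank holds its own copy of a position embedding table, the rank playing the decoder-side (split-stage) role zeros its copy, and a sum all-reduce over a shared process group leaves every copy equal to the encoder-side (first-stage) initialization. The two-process gloo harness, the embedding sizes, and the use of the default world group in place of a dedicated position-embedding group are illustrative assumptions.

```python
# Hypothetical sketch of "make two pipeline stages start from the same
# position-embedding values" via zero-then-all-reduce. Not Megatron-LM code.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank builds and randomly initializes its own copy of the table.
    # Different seeds make the initial values deliberately diverge.
    torch.manual_seed(rank)
    position_embeddings = torch.nn.Embedding(num_embeddings=16, embedding_dim=8)

    # The rank acting as the decoder split stage zeros its copy so the
    # all-reduce below effectively copies the first stage's weights into it.
    if rank != 0:
        with torch.no_grad():
            position_embeddings.weight.zero_()

    # Sum across the group (here the default world group stands in for a
    # dedicated position-embedding group): rank 0 contributes its initialized
    # weights, every other rank contributes zeros, so all copies end up equal.
    dist.all_reduce(position_embeddings.weight.data, op=dist.ReduceOp.SUM)

    print(f"rank {rank} first row: {position_embeddings.weight.data[0, :3]}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```

Zeroing one copy and summing is equivalent to broadcasting the kept initialization to every member of the group; the visible context lines use the same all-reduce approach to keep the shared word embeddings in sync across pipeline stages.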