Commit 2f99b4f4 authored by Lawrence McAfee

copied t5 embed sync to distrib opt

parent f8fd34e3
@@ -625,7 +625,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                 mpu.get_pipeline_model_parallel_world_size() > 1:
             # >>>
-            raise Exception("hi.")
+            # raise Exception("[main] ready for weight sync?")
             # <<<
             if mpu.is_pipeline_first_stage(ignore_virtual=True):
                 unwrapped_model = model[0]
@@ -656,6 +656,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         if mpu.is_rank_in_position_embedding_group() and \
                 mpu.get_pipeline_model_parallel_world_size() > 1 and \
                 args.pipeline_model_parallel_split_rank is not None:
+            # >>>
+            raise Exception("[main] ready for t5 sync?")
+            # <<<
             unwrapped_model = model[0]
             unwrapped_model = unwrap_model(
                 unwrapped_model, (torchDDP, LocalDDP, Float16Module))
@@ -1259,7 +1262,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                 mpu.get_pipeline_model_parallel_world_size() > 1:
             # >>>
-            raise Exception("hi.")
+            raise Exception("[fix] ready for weight sync?")
             # <<<
             if mpu.is_pipeline_first_stage(ignore_virtual=True):
                 unwrapped_model = model[0]
@@ -1284,13 +1287,37 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             # torch.distributed.all_reduce(grad_shard,
             #                              group=mpu.get_embedding_group())
             # <<<
-        timers('backward-embedding-all-reduce').stop()

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Sync T5 position embedding params.
         # ... todo ...
+        # All-reduce position_embeddings grad across first (encoder) and split (decoder)
+        # stages to ensure that position embeddings parameters stay in sync.
+        # This should only run for T5 models with pipeline parallelism
+        if mpu.is_rank_in_position_embedding_group() and \
+                mpu.get_pipeline_model_parallel_world_size() > 1 and \
+                args.pipeline_model_parallel_split_rank is not None:
+            # >>>
+            raise Exception("[fix] ready for t5 sync?")
+            # <<<
+            unwrapped_model = model[0]
+            unwrapped_model = unwrap_model(
+                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+            assert args.DDP_impl == 'local', \
+                'T5 model is only supported with local DDP mode'
+            # >>>
+            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
+            # +++
+            # grad_shard = optimizer.get_grad_shard(
+            #     unwrapped_model.language_model.embedding.position_embeddings.weight)
+            # torch.distributed.all_reduce(grad_shard,
+            #                              group=mpu.get_position_embedding_group())
+            # <<<
+        timers('backward-embedding-all-reduce').stop()

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Reduce-scatter.
         # timers('backward-params-reduce-scatter').start()
...
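
For context, below is a minimal sketch of the sync being copied in this commit, written as a standalone helper. The name sync_t5_position_embedding_grads and its signature are illustrative, not the committed API; the committed code instead guards on mpu.is_rank_in_position_embedding_group(), the pipeline world size, and args.pipeline_model_parallel_split_rank, and takes the process group from mpu.get_position_embedding_group(), as shown in the diff above.

```python
import torch.distributed as dist


def sync_t5_position_embedding_grads(unwrapped_model, position_embedding_group):
    """All-reduce the T5 position-embedding gradient across pipeline stages.

    Sketch only: assumes the caller has already checked that this rank is in
    the position-embedding group, that pipeline parallelism is enabled, and
    that a pipeline split rank is set (i.e. a T5-style encoder/decoder model),
    so the encoder (first) and decoder (split) stages keep identical
    position-embedding weights after the optimizer step.
    """
    weight = unwrapped_model.language_model.embedding.position_embeddings.weight
    # Megatron's local DDP accumulates fp32 gradients into `main_grad`;
    # falling back to `.grad` here is an assumption for this sketch.
    grad = weight.main_grad if hasattr(weight, 'main_grad') else weight.grad
    dist.all_reduce(grad, group=position_embedding_group)
```

The commented-out optimizer.get_grad_shard(...) lines in the diff sketch an alternative for the distributed optimizer, where each rank would all-reduce only its local gradient shard rather than the full main_grad buffer; get_grad_shard is a placeholder name in the diff, not an existing API.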