Commit e5bda3c9 authored by Lawrence McAfee

working w/ mixed parallelism [ dp, tp, pp ].

parent 2f99b4f4
@@ -598,12 +598,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
timers = get_timers()
# <<<
# >>>
# pax(0, {
# "grads" : [ p.main_grad for m in model for p in m.parameters() ],
# })
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
@@ -611,12 +605,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# >>>
# pax(0, {
# "grads" : [ p.main_grad for m in model for p in m.parameters() ],
# })
# <<<
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
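
The model_module.allreduce_gradients() call in the hunk above is Megatron-LM's local-DDP gradient reduction over the data-parallel group. A minimal sketch of what it roughly does, assuming the main_grad buffers used elsewhere in this commit (bucket coalescing omitted; the function name below is illustrative only):

import torch
from megatron import mpu

def allreduce_gradients_sketch(module):
    # Average each parameter's accumulated main_grad across the data-parallel
    # group; the real implementation coalesces grads into contiguous buffers.
    world_size = mpu.get_data_parallel_world_size()
    for param in module.parameters():
        if param.requires_grad and getattr(param, 'main_grad', None) is not None:
            torch.distributed.all_reduce(
                param.main_grad, group=mpu.get_data_parallel_group())
            param.main_grad /= world_size
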
@@ -1246,6 +1234,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
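
The imports added in this hunk are the wrapper-stripping set used elsewhere in the file. A minimal sketch of the likely usage further down in reduce_grads() (the exact call site falls outside the visible context, so this line is an assumption):

# Peel off torchDDP / LocalDDP / Float16Module so the optimizer can reach the
# underlying module (e.g. word_embeddings_weight()); uses the imports above.
unwrapped_model = unwrap_model(model[0], (torchDDP, LocalDDP, Float16Module))
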
@@ -1262,7 +1259,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
raise Exception("[fix] ready for weight sync?")
# raise Exception("[fix] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
......
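
For the final hunk, commenting out the placeholder exception re-enables the tied word-embedding gradient sync between the first and last pipeline stages. A hedged reconstruction of the surrounding pattern, with everything below the visible context assumed from the standard Megatron-LM flow rather than taken from this commit:

if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
        mpu.get_pipeline_model_parallel_world_size() > 1:
    # First and last pipeline stages hold tied copies of the word embeddings.
    if mpu.is_pipeline_first_stage(ignore_virtual=True):
        unwrapped_model = model[0]
    elif mpu.is_pipeline_last_stage(ignore_virtual=True):
        unwrapped_model = model[-1]
    unwrapped_model = unwrap_model(
        unwrapped_model, (torchDDP, LocalDDP, Float16Module))
    if unwrapped_model.share_word_embeddings:
        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
        # With local DDP the accumulated gradient lives in .main_grad.
        grad = word_embeddings_weight.main_grad
        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
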