Commit 64b9d28c authored by Lawrence McAfee

renamed reduce_grads/gather_params -> reduce_model_grads/gather_model_params

parent d58d1762
@@ -381,21 +381,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return gbuf_view_items

-    # def reduce_grads(self, model):
-    def reduce_grads(self, args, timers):
-
-        # >>>
-        # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-        # from megatron import get_args
-        # from megatron import get_timers
-        # from megatron.model import DistributedDataParallel as LocalDDP
-        # from megatron.model import Float16Module
-        # from megatron.utils import unwrap_model
-        # args = get_args()
-        # timers = get_timers()
-        # <<<
+    def reduce_model_grads(self, args, timers):
+        '''Note: this is a different order of reduction, versus the non-
+        distributed optimizer, which reduces: 1) all grads, 2) embedding
+        grads.
+        '''

         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
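For orientation, here is a minimal sketch, assuming hypothetical helper names, of the reduction order the new docstring describes: the distributed optimizer all-reduces embedding grads first and then reduces the remaining model grads across the data-parallel group. Only the timer labels are taken from the diff; this is not Megatron-LM's actual implementation.

class _DistributedOptimizerSketch:
    '''Illustrative only; not the real DistributedOptimizer.'''

    def reduce_model_grads(self, args, timers):
        # 1) All-reduce embedding grads first (the non-distributed
        #    optimizer performs this step last).
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)           # hypothetical helper
        timers('backward-embedding-all-reduce').stop()

        # 2) Reduce the remaining parameter grads across the
        #    data-parallel group.
        timers('backward-params-all-reduce').start()
        self.reduce_data_parallel_grads(args)          # hypothetical helper
        timers('backward-params-all-reduce').stop()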
@@ -420,7 +410,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-reduce').stop()

-    def gather_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers, ITERATION):

         # >>>
         # timers = get_timers()
@@ -454,8 +444,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # pax(0, {"gbuf_view_items": gbuf_view_items})

         # >>>
-        # self.debug_main(ITERATION, "after/inside gather_params.", 0)
-        # self.debug_model(ITERATION, "after/inside gather_params.", 0)
+        # self.debug_main(ITERATION, "after/inside gather_model_params.", 0)
+        # self.debug_model(ITERATION, "after/inside gather_model_params.", 0)
         # if ITERATION == 2:
         # pax(1, {
...
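For context on the method being renamed here, a short sketch of what a gather step like gather_model_params typically does in a sharded (ZeRO-style) optimizer: after the optimizer step, each data-parallel rank all-gathers the updated parameter shards back into the full model parameters. The attributes data_parallel_group and param_buffers are assumptions for illustration, not Megatron-LM's real names.

import torch
import torch.distributed as dist

class _GatherSketch:
    '''Illustrative only.'''

    def gather_model_params(self, args, timers, ITERATION):
        group = self.data_parallel_group               # assumed attribute
        world_size = dist.get_world_size(group=group)
        # Each rank holds one updated shard; all-gather so every rank
        # ends up with the full, updated parameter buffer.
        for local_shard, full_buffer in self.param_buffers:  # assumed structure
            shards = [torch.empty_like(local_shard) for _ in range(world_size)]
            dist.all_gather(shards, local_shard, group=group)
            full_buffer.copy_(torch.cat(shards)[:full_buffer.numel()])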
@@ -180,7 +180,9 @@ class MegatronOptimizer(ABC):
     def step(self, args, timers):
         pass

-    def gather_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers, ITERATION):
+        '''For the case of a non-distributed-optimizer, there is nothing to
+        do here.'''
         pass

     def allreduce_word_embedding_grads(self):
@@ -236,8 +238,7 @@ class MegatronOptimizer(ABC):
         self.allreduce_word_embedding_grads()
         self.allreduce_position_embedding_grads()

-    # def reduce_grads(self, model):
-    def reduce_grads(self, args, timers):
+    def reduce_model_grads(self, args, timers):

         # pax(0, {
         # "*models" : self.models,
...
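Together, these two hunks give the MegatronOptimizer base class the renamed hooks: a no-op gather_model_params and a reduce_model_grads that, per the docstring added in the distributed optimizer, reduces 1) all grads and then 2) embedding grads. A condensed sketch of that shape; the timer labels are reused from the first hunk as an assumption, and allreduce_all_grads is a hypothetical stand-in for the existing all-reduce logic.

from abc import ABC

class MegatronOptimizer(ABC):

    def gather_model_params(self, args, timers, ITERATION):
        '''For the case of a non-distributed-optimizer, there is nothing to
        do here.'''
        pass

    def reduce_model_grads(self, args, timers):
        # Non-distributed order: 1) all grads, 2) embedding grads.
        timers('backward-params-all-reduce').start()
        self.allreduce_all_grads(args)                 # hypothetical helper
        timers('backward-params-all-reduce').stop()

        timers('backward-embedding-all-reduce').start()
        self.allreduce_word_embedding_grads()
        self.allreduce_position_embedding_grads()
        timers('backward-embedding-all-reduce').stop()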
@@ -437,7 +437,7 @@ def train_step(forward_step_func, data_iterator,
     # >>>
     # Reduce gradients.
-    optimizer.reduce_grads(args, timers)
+    optimizer.reduce_model_grads(args, timers)
     # <<<

     # Update parameters.
@@ -447,7 +447,7 @@ def train_step(forward_step_func, data_iterator,
     # >>>
     # Gather params.
-    optimizer.gather_params(args, timers, ITERATION)
+    optimizer.gather_model_params(args, timers, ITERATION)
     # <<<

     # >>>
@@ -464,7 +464,9 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    # Empty unused memory
+    # >>>
+    # Empty unused memory.
+    # <<<
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
...
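Putting the train_step changes together, the training loop now calls the renamed hooks in the order reduce_model_grads -> step -> gather_model_params, then optionally frees cached memory. A condensed sketch; the abbreviated signature and the update_successful return value of optimizer.step() are assumptions, and the forward/backward and learning-rate logic of the real train_step is omitted.

import torch

def train_step(forward_step_func, data_iterator, model, optimizer,
               args, timers, ITERATION):
    # ... forward and backward passes happen above this point ...

    # Reduce gradients (embedding grads first for the distributed optimizer).
    optimizer.reduce_model_grads(args, timers)

    # Update parameters.
    update_successful = optimizer.step(args, timers)   # assumed return value

    # Gather params (no-op for the non-distributed optimizer).
    optimizer.gather_model_params(args, timers, ITERATION)

    skipped_iter = 0 if update_successful else 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    return skipped_iter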