Commit e46230dc authored by Lawrence McAfee

moved 'reduce_grads()' to MegatronOptimizer.
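For orientation, here is a minimal sketch of the class layout this change is aiming for, inferred from the diff below. Method bodies are elided and the docstrings are editorial notes, not text from the actual file:

    from abc import ABC, abstractmethod


    class MegatronOptimizer(ABC):

        def reduce_grads(self, model):
            """Concrete and shared by every subclass: the data-parallel
            all-reduce of gradients plus the embedding-grad syncs across
            pipeline stages (full body in the diff below)."""

        @abstractmethod
        def step(self):
            pass

        def gather_params(self):
            """Default no-op hook; subclasses may override."""


    class MixedPrecisionOptimizer(MegatronOptimizer):
        """Mixed-precision base class referenced below (formerly BaseFloat16Optimizer)."""

        def step(self):
            pass


    class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
        """No longer defines reduce_grads(); it inherits the base-class version."""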

parent 772a4a2d
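The call site is not part of this diff, but the intended flow is easier to follow with a rough usage sketch. Everything below (the train_step name, the forward_backward_func callable, the exact signatures) is a hypothetical illustration, not code from this commit:

    def train_step(model, optimizer, forward_backward_func):
        # Hypothetical caller, for illustration only.
        optimizer.zero_grad()

        # Forward/backward passes populate the model-side (main_)grad buffers.
        losses = forward_backward_func(model)

        # A single hook on the optimizer base class now performs the data-parallel
        # all-reduce and the embedding-group syncs across pipeline stages.
        optimizer.reduce_grads(model)

        # Parameter update; the mixed-precision step reports success, the grad
        # norm, and (optionally) the number of zeros in the gradients.
        success, grad_norm, num_zeros_in_grad = optimizer.step()

        # No-op for the float16 optimizer shown here; a distributed optimizer
        # could override this to re-gather updated parameters.
        optimizer.gather_params()

        return losses, success, grad_norm, num_zeros_in_grad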
...@@ -124,21 +124,6 @@ class MegatronOptimizer(ABC):
return self.get_loss_scale() * loss
@abstractmethod
def reduce_grads(self):
pass
@abstractmethod
def step(self):
pass
@abstractmethod
def gather_params(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
...@@ -182,6 +167,80 @@ class MegatronOptimizer(ABC):
param_groups = property(_get_param_groups, _set_param_groups)
@abstractmethod
def step(self):
pass
def gather_params(self):
pass
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[main] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
# >>>
raise Exception("[main] ready for t5 sync?")
# <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# class BaseFloat16Optimizer(MegatronOptimizer):
class MixedPrecisionOptimizer(MegatronOptimizer):
...@@ -251,15 +310,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
# >>>
# torch.distributed.all_reduce(self.found_inf,
# op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
# +++
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=self.get_model_parallel_group())
# <<<
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
...@@ -267,58 +320,58 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return found_inf_flag
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@classmethod
def debug_base(cls, ITERATION, key, value):
from megatron import get_args
args = get_args()
my_rank = torch.distributed.get_rank()
if ITERATION != DEBUG_ITERATION:
return
for r in range(torch.distributed.get_world_size()):
if my_rank == r:
print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
torch.distributed.barrier()
torch.distributed.barrier()
# if my_rank == 0:
# raise Exception("debug.")
# else:
# exit(0)
exit(0)
def debug_model(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
(p.main_grad.float() if use_grad else p.float())
for m in self.models for p in m.parameters()
]
count = sum(t.nelement() for t in tensors)
return self.debug_base(
ITERATION,
"model/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
# sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
sum(torch.sum(torch.abs(t)) for t in tensors),
)
def debug_main(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
p.grad if use_grad else p
for g in self.optimizer.param_groups
for p in g["params"]
]
tensors = [ t.float() for t in tensors ]
count = sum(t.nelement() for t in tensors)
return self.debug_base(
ITERATION,
"main/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
sum(torch.sum(torch.abs(t)) for t in tensors),
)
# @classmethod
# def debug_base(cls, ITERATION, key, value):
# from megatron import get_args
# args = get_args()
# my_rank = torch.distributed.get_rank()
# if ITERATION != DEBUG_ITERATION:
# return
# for r in range(torch.distributed.get_world_size()):
# if my_rank == r:
# print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
# torch.distributed.barrier()
# torch.distributed.barrier()
# # if my_rank == 0:
# # raise Exception("debug.")
# # else:
# # exit(0)
# exit(0)
# def debug_model(self, ITERATION, key, use_grad):
# use_grad = bool(use_grad)
# tensors = [
# (p.main_grad.float() if use_grad else p.float())
# for m in self.models for p in m.parameters()
# ]
# count = sum(t.nelement() for t in tensors)
# return self.debug_base(
# ITERATION,
# "model/%s, %s [count %d]" % (
# "grad" if use_grad else "param",
# key,
# count,
# ),
# # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
# sum(torch.sum(torch.abs(t)) for t in tensors),
# )
# def debug_main(self, ITERATION, key, use_grad):
# use_grad = bool(use_grad)
# tensors = [
# p.grad if use_grad else p
# for g in self.optimizer.param_groups
# for p in g["params"]
# ]
# tensors = [ t.float() for t in tensors ]
# count = sum(t.nelement() for t in tensors)
# return self.debug_base(
# ITERATION,
# "main/%s, %s [count %d]" % (
# "grad" if use_grad else "param",
# key,
# count,
# ),
# sum(torch.sum(torch.abs(t)) for t in tensors),
# )
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@torch.no_grad()
...@@ -327,10 +380,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers = get_timers()
# >>>
# self.debug_model_param(ITERATION, "before copy grad.")
# self.debug_model_grad(ITERATION, "before copy grad.")
# self.debug_main_param(ITERATION, "before copy grad.")
# self.debug_main_grad(ITERATION, "before copy grad.")
# self.debug_model(ITERATION, "before copy grad.", 0)
# self.debug_main(ITERATION, "before copy grad.", 0)
# <<<
# Copy gradients from model params to main params.
...@@ -338,11 +389,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
self._copy_model_grads_to_main_grads(ITERATION)
timers('optimizer-copy-to-main-grad').stop()
# >>>
# self.debug_model(ITERATION, "after copy grad.", 0)
# self.debug_main(ITERATION, "after copy grad.", 1)
# <<<
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
...@@ -358,11 +404,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
# If we found inf/nan, skip the update.
if found_inf_flag:
pax(0, {
"main params" : self.get_main_params(),
"main grads" : self.get_main_grads(),
"found_inf_flag" : found_inf_flag,
})
return False, None, None
# Clip the main gradients.
...@@ -376,41 +417,21 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# >>>
# param = self.optimizer.param_groups[0]["params"][0]
# pax(0, {
# "param" : tp(param),
# "grad" : tp(param.grad),
# })
# <<<
# >>>
# self.debug_main(ITERATION, "before step.", 0)
# <<<
# Step the optimizer.
self.optimizer.step()
# >>>
# self.debug_main(ITERATION, "after step.", 0)
# <<<
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params(ITERATION)
timers('optimizer-copy-main-to-model-params').stop()
# >>>
# self.debug_main_param(ITERATION, "after copy param.")
# self.debug_main_grad(ITERATION, "after copy param.")
# <<<
# Successful update.
return True, grad_norm, num_zeros_in_grad
# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
"""Float16 optimizer for fp16 and bf16 data types. """Float16 optimizer for fp16 and bf16 data types.
Arguments: Arguments:
...@@ -482,17 +503,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer): ...@@ -482,17 +503,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
fp32_from_float16_params_this_group.append(main_param) fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param. # Reset existing state dict key to the new main param.
if param in self.optimizer.state: if param in self.optimizer.state:
# >>>
raise Exception("hi.")
# <<<
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
# >>>
pax(0, {"param": param})
# <<<
fp32_params_this_group.append(param)
param_group['params'][i] = param
...@@ -512,19 +527,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# params = self.get_parameters()
# pax(0, {
# # "params / 0" : params[0],
# "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
# "grads" : [ (param_is_not_tensor_parallel_duplicate(p.grad), tp(p.grad)) for p in params ],
# })
# <<<
def get_model_parallel_group(self):
return mpu.get_model_parallel_group()
def zero_grad(self, set_to_none=True):
...@@ -541,76 +546,35 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
_zero_grad_group_helper(group, set_to_none)
# >>>
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[main] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
# >>>
raise Exception("[main] ready for t5 sync?")
# <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
def gather_params(self, ITERATION):
pass
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self, ITERATION):
# This only needs to be done for the float16 group.
...@@ -653,49 +617,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# })
# <<<
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_main_params_to_model_params(self, ITERATION):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
overflow_buf=self._dummy_overflow_buf)
# >>>
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** main. **",
# "ITERATION" : ITERATION,
# "model params" : [p for m in self.models for p in m.parameters()],
# })
# <<<
def _copy_model_params_to_main_params(self):
...