Commit e46230dc authored by Lawrence McAfee

moved 'reduce_grads()' to MegatronOptimizer.

parent 772a4a2d
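For orientation, a minimal hedged sketch of the class shape this commit moves toward; the names below are illustrative stand-ins, not the actual file contents. reduce_grads() stops being an @abstractmethod that every optimizer subclass must implement and becomes a concrete method on the base MegatronOptimizer, while step() stays abstract.

from abc import ABC, abstractmethod

class MegatronOptimizerSketch(ABC):
    """Illustrative stand-in for megatron.optimizer.MegatronOptimizer."""

    def reduce_grads(self, model):
        # Now concrete on the base class: the shared all-reduce of DDP grads
        # and of tied word/position embedding grads lives here (see the hunk
        # below).
        pass

    @abstractmethod
    def step(self):
        ...

class Float16OptimizerSketch(MegatronOptimizerSketch):
    def step(self):
        # Subclasses keep only optimizer-specific logic (unscale, clip, update).
        ...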
......@@ -124,21 +124,6 @@ class MegatronOptimizer(ABC):
return self.get_loss_scale() * loss
@abstractmethod
def reduce_grads(self):
pass
@abstractmethod
def step(self):
pass
@abstractmethod
def gather_params(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
......@@ -182,6 +167,80 @@ class MegatronOptimizer(ABC):
param_groups = property(_get_param_groups, _set_param_groups)
@abstractmethod
def step(self):
pass
def gather_params(self):
pass
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[main] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
# >>>
raise Exception("[main] ready for t5 sync?")
# <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
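As context for the 'backward-embedding-all-reduce' block in the new reduce_grads() above: every rank that holds a copy of a tied embedding all-reduces its gradient so the copies stay identical after the update. The helper below is a minimal, self-contained illustration of that pattern only; its name and the group argument are assumptions, not part of this commit, and it presumes torch.distributed is already initialized.

import torch
import torch.distributed as dist

def sync_tied_embedding_grad(weight: torch.Tensor, group=None) -> None:
    # Prefer the fused main_grad when the local-DDP wrapper provides one,
    # otherwise fall back to the ordinary autograd .grad.
    grad = weight.main_grad if hasattr(weight, "main_grad") else weight.grad
    if grad is not None:
        # Sum contributions from every rank that owns a copy of the tied
        # parameter (e.g., first and last pipeline stages for word embeddings).
        dist.all_reduce(grad, group=group)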
# class BaseFloat16Optimizer(MegatronOptimizer):
class MixedPrecisionOptimizer(MegatronOptimizer):
......@@ -251,15 +310,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
# >>>
# torch.distributed.all_reduce(self.found_inf,
# op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
# +++
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=self.get_model_parallel_group())
# <<<
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
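A minimal sketch of the overflow check above, under the assumption that each rank keeps found_inf as a one-element float tensor set greater than zero when any local gradient contained inf/nan: taking the MAX across the model-parallel group means a single affected rank forces every rank to skip the update together.

import torch
import torch.distributed as dist

def found_inf_across_ranks(found_inf: torch.Tensor, group=None) -> bool:
    # After the MAX all-reduce, every rank sees the same flag value.
    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=group)
    return found_inf.item() > 0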
......@@ -267,58 +320,58 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return found_inf_flag
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@classmethod
def debug_base(cls, ITERATION, key, value):
from megatron import get_args
args = get_args()
my_rank = torch.distributed.get_rank()
if ITERATION != DEBUG_ITERATION:
return
for r in range(torch.distributed.get_world_size()):
if my_rank == r:
print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
torch.distributed.barrier()
torch.distributed.barrier()
# if my_rank == 0:
# raise Exception("debug.")
# else:
# exit(0)
exit(0)
def debug_model(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
(p.main_grad.float() if use_grad else p.float())
for m in self.models for p in m.parameters()
]
count = sum(t.nelement() for t in tensors)
return self.debug_base(
ITERATION,
"model/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
# sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
sum(torch.sum(torch.abs(t)) for t in tensors),
)
def debug_main(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
p.grad if use_grad else p
for g in self.optimizer.param_groups
for p in g["params"]
]
tensors = [ t.float() for t in tensors ]
count = sum(t.nelement() for t in tensors)
return self.debug_base(
ITERATION,
"main/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
sum(torch.sum(torch.abs(t)) for t in tensors),
)
# @classmethod
# def debug_base(cls, ITERATION, key, value):
# from megatron import get_args
# args = get_args()
# my_rank = torch.distributed.get_rank()
# if ITERATION != DEBUG_ITERATION:
# return
# for r in range(torch.distributed.get_world_size()):
# if my_rank == r:
# print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
# torch.distributed.barrier()
# torch.distributed.barrier()
# # if my_rank == 0:
# # raise Exception("debug.")
# # else:
# # exit(0)
# exit(0)
# def debug_model(self, ITERATION, key, use_grad):
# use_grad = bool(use_grad)
# tensors = [
# (p.main_grad.float() if use_grad else p.float())
# for m in self.models for p in m.parameters()
# ]
# count = sum(t.nelement() for t in tensors)
# return self.debug_base(
# ITERATION,
# "model/%s, %s [count %d]" % (
# "grad" if use_grad else "param",
# key,
# count,
# ),
# # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
# sum(torch.sum(torch.abs(t)) for t in tensors),
# )
# def debug_main(self, ITERATION, key, use_grad):
# use_grad = bool(use_grad)
# tensors = [
# p.grad if use_grad else p
# for g in self.optimizer.param_groups
# for p in g["params"]
# ]
# tensors = [ t.float() for t in tensors ]
# count = sum(t.nelement() for t in tensors)
# return self.debug_base(
# ITERATION,
# "main/%s, %s [count %d]" % (
# "grad" if use_grad else "param",
# key,
# count,
# ),
# sum(torch.sum(torch.abs(t)) for t in tensors),
# )
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@torch.no_grad()
......@@ -327,10 +380,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers = get_timers()
# >>>
# self.debug_model_param(ITERATION, "before copy grad.")
# self.debug_model_grad(ITERATION, "before copy grad.")
# self.debug_main_param(ITERATION, "before copy grad.")
# self.debug_main_grad(ITERATION, "before copy grad.")
# self.debug_model(ITERATION, "before copy grad.", 0)
# self.debug_main(ITERATION, "before copy grad.", 0)
# <<<
# Copy gradients from model params to main params.
......@@ -338,11 +389,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
self._copy_model_grads_to_main_grads(ITERATION)
timers('optimizer-copy-to-main-grad').stop()
# >>>
# self.debug_model(ITERATION, "after copy grad.", 0)
# self.debug_main(ITERATION, "after copy grad.", 1)
# <<<
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
......@@ -358,11 +404,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
# If we found inf/nan, skip the update.
if found_inf_flag:
pax(0, {
"main params" : self.get_main_params(),
"main grads" : self.get_main_grads(),
"found_inf_flag" : found_inf_flag,
})
return False, None, None
# Clip the main gradients.
......@@ -376,41 +417,21 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# >>>
# param = self.optimizer.param_groups[0]["params"][0]
# pax(0, {
# "param" : tp(param),
# "grad" : tp(param.grad),
# })
# <<<
# >>>
# self.debug_main(ITERATION, "before step.", 0)
# <<<
# Step the optimizer.
self.optimizer.step()
# >>>
# self.debug_main(ITERATION, "after step.", 0)
# <<<
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params(ITERATION)
timers('optimizer-copy-main-to-model-params').stop()
# >>>
# self.debug_main_param(ITERATION, "after copy param.")
# self.debug_main_grad(ITERATION, "after copy param.")
# <<<
# Successful update.
return True, grad_norm, num_zeros_in_grad
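For orientation, a hedged outline of the order of operations in the step above. Method names are taken from the hunks where visible and are otherwise placeholders; the real method also drives timers and the ITERATION debug argument, and the gradient-clipping call itself is elided in this diff.

def mixed_precision_step_outline(opt, ITERATION=None):
    # 1. Copy model (fp16/bf16) grads into the fp32 main grads.
    opt._copy_model_grads_to_main_grads(ITERATION)
    # 2. With a grad scaler: unscale, all-reduce the inf/nan flag across
    #    model-parallel ranks, and skip the whole update on overflow.
    if opt.grad_scaler and opt._unscale_main_grads_and_check_for_nan():
        return False, None, None
    # 3. Clip the main grads (call elided in the hunk), optionally count
    #    zeros, then step the wrapped optimizer on the fp32 main params.
    num_zeros_in_grad = opt.count_zeros() if opt.log_num_zeros_in_grad else None
    grad_norm = None  # produced by the elided clipping step
    opt.optimizer.step()
    # 4. Copy updated fp32 main params back onto the fp16/bf16 model params.
    opt._copy_main_params_to_model_params(ITERATION)
    return True, grad_norm, num_zeros_in_grad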
# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
......@@ -482,17 +503,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
# >>>
raise Exception("hi.")
# <<<
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
# >>>
pax(0, {"param": param})
# <<<
fp32_params_this_group.append(param)
param_group['params'][i] = param
......@@ -512,19 +527,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# params = self.get_parameters()
# pax(0, {
# # "params / 0" : params[0],
# "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
# "grads" : [ (param_is_not_tensor_parallel_duplicate(p.grad), tp(p.grad)) for p in params ],
# })
# <<<
def get_model_parallel_group(self):
return mpu.get_model_parallel_group()
def zero_grad(self, set_to_none=True):
......@@ -541,76 +546,35 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
_zero_grad_group_helper(group, set_to_none)
# >>>
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[main] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
# >>>
raise Exception("[main] ready for t5 sync?")
# <<<
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def gather_params(self, ITERATION):
pass
def _copy_model_grads_to_main_grads(self, ITERATION):
# This only needs to be done for the float16 group.
......@@ -653,49 +617,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# })
# <<<
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_main_params_to_model_params(self, ITERATION):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
overflow_buf=self._dummy_overflow_buf)
# >>>
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** main. **",
# "ITERATION" : ITERATION,
# "model params" : [p for m in self.models for p in m.parameters()],
# })
# <<<
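The _multi_tensor_copy_this_to_that call above is what pushes the updated fp32 main params back into the fp16/bf16 model params. As a hedged illustration of what that helper accomplishes (the real one can dispatch to a fused multi-tensor kernel when available; that detail is an assumption here), a plain per-tensor fallback looks like this:

from typing import List
import torch

def copy_this_to_that(this: List[torch.Tensor], that: List[torch.Tensor]) -> None:
    # Element-wise copy of each fp32 "main" tensor into its half-precision
    # "model" counterpart; copy_() performs the dtype conversion.
    for this_, that_ in zip(this, that):
        that_.copy_(this_)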
def _copy_model_params_to_main_params(self):
......