Commit 9546d8f0 authored by Lawrence McAfee

passing 'model_parallel_group' to clip_grads, count_zeros

parent e46230dc
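For orientation before the hunks below: a condensed sketch of the call pattern this commit moves to, assembled from the changes themselves. It is not code from the commit, and the import path is an assumption about the module layout. The grad-clipping and zero-counting helpers now receive the process group explicitly, and the optimizer base class supplies a default that the distributed optimizer can override.

# Condensed sketch assembled from the hunks below; not part of the commit.
# The import path is an assumption; get_parameters() is defined elsewhere
# in the real class.
from megatron import mpu
from megatron.optimizer.clip_grads import clip_grad_norm_fp32, count_zeros_fp32


class MegatronOptimizer:

    def get_model_parallel_group(self):
        # Default group; the distributed optimizer overrides this method.
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad, ITERATION):
        params = self.get_parameters()
        return clip_grad_norm_fp32(
            params, clip_grad,
            model_parallel_group=self.get_model_parallel_group(),
            ITERATION=ITERATION)

    def count_zeros(self):
        params = self.get_parameters()
        return count_zeros_fp32(
            params, model_parallel_group=self.get_model_parallel_group())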
@@ -91,18 +91,6 @@ def get_megatron_optimizer(model,
                                     scale_lr_cond,
                                     lr_mult)

-    # >>>
-    # params = [ p for m in model for p in m.parameters() ]
-    # pax(0, {
-    #     "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
-    # })
-    # <<<
-    # >>>
-    # if args.use_distributed_optimizer:
-    #     optimizer = DistributedFusedAdam(param_groups)
-    # elif args.optimizer == 'adam':
-    # <<<
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
@@ -123,7 +111,7 @@ def get_megatron_optimizer(model,
     if args.DDP_impl == 'local':
         params_have_main_grad = True

-    if args.fp16 or args.bf16:
+    if args.fp16 or args.bf16 or args.use_distributed_optimizer:

         # Grad scaler:
         # if loss-scale is provided, instantiate the constant scaler.
@@ -148,10 +136,10 @@ def get_megatron_optimizer(model,
         # Megatron optimizer.
         # >>>
-        opt_ty = Float16DistributedOptimizer \
-            if args.use_distributed_optimizer \
-            else Float16OptimizerWithFloat16Params
-        opt = opt_ty(optimizer,
+        opt_ty = DistributedOptimizer \
+            if args.use_distributed_optimizer else \
+            Float16OptimizerWithFloat16Params
+        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
@@ -159,20 +147,16 @@ def get_megatron_optimizer(model,
                      args.bf16,
                      grad_scaler,
                      model)
-        # >>>
-        # opt.debug_main_param_sum(0, "after init")
-        # opt.debug_main_grad_sum(0, "after init")
-        # <<<
-        return opt
         # <<<

     # FP32.
     # >>>
-    opt_ty = Float32DistributedOptimizer \
-        if args.use_distributed_optimizer \
-        else Float32Optimizer
-    return opt_ty(optimizer, args.clip_grad,
-                  args.log_num_zeros_in_grad,
-                  params_have_main_grad,
-                  args.use_contiguous_buffers_in_local_ddp)
+    # opt_ty = Float32DistributedOptimizer \
+    #     if args.use_distributed_optimizer \
+    #     else Float32Optimizer
+    # return opt_ty(optimizer, args.clip_grad,
     # <<<
+    return Float32Optimizer(optimizer, args.clip_grad,
+                            args.log_num_zeros_in_grad,
+                            params_have_main_grad,
+                            args.use_contiguous_buffers_in_local_ddp)
@@ -21,7 +21,9 @@ from torch._six import inf
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import mpu
+# >>>
+# from megatron import mpu
+# <<<
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
@@ -31,7 +33,9 @@ from lutil import pax, tp
 DEBUG_ITERATION = 1
 # <<<

-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
+def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
+                        model_parallel_group=None,
+                        ITERATION=None):
     """Clips gradient norm of an iterable of parameters whose gradients
     are in fp32.
@@ -45,13 +49,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
             infinity norm.
+        model_parallel_group (group): due to the nature of the distributed
+            optimizer, this is passed as an argument.

     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
     # >>>
-    raise Exception("currently debugging ... don't call me.")
+    # raise Exception("currently debugging ... don't call me.")
     # <<<

     if isinstance(parameters, torch.Tensor):
@@ -75,26 +81,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
             grads_for_norm.append(grad)
-        # >>>
-        # else:
-        #     pax(1, {
-        #         "grad_not_none" : grad_not_none,
-        #         "is_not_shared" : is_not_shared,
-        #         "is_not_tp_duplicate" : is_not_tp_duplicate,
-        #     })
-        # <<<
-    # >>>
-    # if ITERATION == DEBUG_ITERATION:
-    #     pax(0, {
-    #         "[LOC]" : "[** BEFORE CALC NORM **]",
-    #         "[ITERATION]" : ITERATION,
-    #         "max_norm" : max_norm,
-    #         "parameters" : parameters,
-    #         # "grads" : grads,
-    #         "grads_for_norm" : grads_for_norm,
-    #     })
-    # <<<

     # Norm parameters.
     max_norm = float(max_norm)
@@ -108,7 +94,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
         total_norm = total_norm_cuda[0].item()

     else:
@@ -117,13 +103,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
             # Use apex's multi-tensor applier for efficiency reasons.
             # Multi-tensor applier takes a function and a list of list
             # and performs the operation on that list all in one kernel.
-            # >>>
-            # pax(1, {
-            #     # "fn" : amp_C.multi_tensor_l2norm,
-            #     "dummy_overflow_buf" : tp(dummy_overflow_buf),
-            #     "grads_for_norm" : grads_for_norm,
-            # })
-            # <<<
             grad_norm, _ = multi_tensor_applier(
                 amp_C.multi_tensor_l2norm,
                 dummy_overflow_buf,
@@ -139,18 +118,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
-            # >>>
-            # if ITERATION == DEBUG_ITERATION:
-            #     pax(0, {
-            #         "[LOC]" : "[** CALC NORM **]",
-            #         "[ITERATION]" : ITERATION,
-            #         "max_norm" : max_norm,
-            #         "norm_type" : norm_type,
-            #         "grad_norm" : tp(grad_norm),
-            #         "total_norm" : tp(total_norm),
-            #     })
-            # <<<

         # Sum across all model-parallel GPUs.
         # >>>
         from megatron import get_args
@@ -161,22 +128,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
         else:
             torch.distributed.all_reduce(total_norm,
                                          op=torch.distributed.ReduceOp.SUM,
-                                         group=mpu.get_model_parallel_group())
+                                         group=model_parallel_group)
         # <<<
         total_norm = total_norm.item() ** (1.0 / norm_type)
-        # >>>
-        # if ITERATION == DEBUG_ITERATION:
-        #     pax(0, {
-        #         "[LOC]" : "[** AFTER REDUCE. **]",
-        #         "[ITERATION]" : ITERATION,
-        #         "max_norm" : max_norm,
-        #         "norm_type" : norm_type,
-        #         "grad_norm" : grad_norm.item(),
-        #         "total_norm" : total_norm,
-        #     })
-        # <<<

     # Scale.
     clip_coeff = max_norm / (total_norm + 1.0e-6)
     if clip_coeff < 1.0:
@@ -186,22 +141,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
                              [grads, grads],
                              clip_coeff)
-    # >>>
-    # # from pygit2 import Repository
-    # if ITERATION == DEBUG_ITERATION:
-    #     pax(1, {
-    #         "[LOC]" : "[** CLIP / FINAL **]",
-    #         "[ITERATION]" : ITERATION,
-    #         "grads" : grads,
-    #         "clip_coeff" : tp(clip_coeff),
-    #         # "repo" : Repository('.').head.shorthand,
-    #     })
-    # <<<

     return total_norm


-def count_zeros_fp32(parameters):
+def count_zeros_fp32(parameters, model_parallel_group):

     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
@@ -231,7 +174,7 @@ def count_zeros_fp32(parameters):
     else:
         torch.distributed.all_reduce(total_num_zeros,
                                      op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
     # <<<
     total_num_zeros = total_num_zeros.item()
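As context for the all_reduce changes above: a minimal, self-contained sketch of the norm reduction with the group passed in explicitly. It is simplified (no multi_tensor_applier, no clipping), `_total_grad_norm` is a hypothetical helper name rather than the function in the commit, and an initialized torch.distributed process group plus CUDA tensors are assumed.

# Hypothetical sketch of the reduction pattern above; not code from the commit.
# Assumes torch.distributed.init_process_group() has already been called.
import torch


def _total_grad_norm(grads_for_norm, norm_type=2.0, model_parallel_group=None):
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take the max across all ranks of the given model-parallel group.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=model_parallel_group)
        return total_norm_cuda[0].item()
    total_norm = torch.cuda.FloatTensor([0.0])
    for grad in grads_for_norm:
        total_norm += torch.norm(grad, norm_type) ** norm_type
    # Sum the p-th powers across the group, then take the p-th root.
    torch.distributed.all_reduce(total_norm,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=model_parallel_group)
    return total_norm.item() ** (1.0 / norm_type)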
@@ -17,8 +17,13 @@
 import math

+import torch
+
 from megatron import get_args
+from megatron import get_timers
+from megatron import mpu

+from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
 # >>>
 from lutil import pax, tp
@@ -40,7 +45,8 @@ class Shard:
 # class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
 # class Float16DistributedOptimizer(MegatronOptimizer):
 # class Float16DistributedOptimizer(BaseFloat16Optimizer):
-class DistributedOptimizer(MegatronOptimizer):
+# class DistributedOptimizer(MegatronOptimizer):
+class DistributedOptimizer(MixedPrecisionOptimizer):

     @classmethod
     def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
@@ -98,14 +98,23 @@ class MegatronOptimizer(ABC):
         return params

+    def get_model_parallel_group(self):
+        '''Default returned here, but the distributed optimizer overrides this.'''
+        return mpu.get_model_parallel_group()
+
     def clip_grad_norm(self, clip_grad, ITERATION):
         params = self.get_parameters()
-        return clip_grad_norm_fp32(params, clip_grad, ITERATION = ITERATION)
+        return clip_grad_norm_fp32(
+            params, clip_grad,
+            model_parallel_group=self.get_model_parallel_group(),
+            ITERATION = ITERATION)

     def count_zeros(self):
         params = self.get_parameters()
-        return count_zeros_fp32(params)
+        return count_zeros_fp32(params,
+                                model_parallel_group=self.get_model_parallel_group())

     @abstractmethod
@@ -171,7 +180,7 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass

-    def gather_params(self):
+    def gather_params(self, ITERATION):
         pass

     def reduce_grads(self, model):
@@ -282,10 +291,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._scale_one = torch.cuda.FloatTensor([1.0])

-    @abstractmethod
-    def get_model_parallel_group(self, state_dict):
-        pass
-
     def get_loss_scale(self):
         if self.grad_scaler is None:
             return self._scale_one
@@ -296,7 +301,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()

-    def _unscale_main_grads_and_check_for_nan(self, group):
+    def _unscale_main_grads_and_check_for_nan(self):

         # Collect main grads.
         main_grads = self._collect_main_grad_data_for_unscaling()
@@ -528,10 +533,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-    def get_model_parallel_group(self):
-        return mpu.get_model_parallel_group()
-
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
         float16_groups & fp32_from_fp32_groups. We additionally zero