Commit a7782b21 authored by Lawrence McAfee

fixed clip_grad_norm bug

parent d1f5776a
@@ -21,6 +21,7 @@ from megatron.model import LayerNorm
# >>>
# from .distributed_fused_adam import DistributedFusedAdam
from lutil import pax, tp
# <<<
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
# >>>
@@ -94,14 +95,9 @@ def get_megatron_optimizer(model,
lr_mult)
# >>>
-# from lutil import pax
+# params = [ p for m in model for p in m.parameters() ]
# pax(0, {
-# "model" : model,
+# "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
# # "param_groups" : param_groups,
# # "param_groups / 0" : param_groups[0],
# # "param_groups / 0 / params" : param_groups[0]["params"],
# # "param_groups / 1" : param_groups[1],
# # "param_groups / 1 / params" : param_groups[1]["params"],
# })
# <<<
...
@@ -26,6 +26,10 @@ from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# >>>
from lutil import pax, tp
# <<<
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
@@ -66,6 +70,19 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# >>>
# else:
# pax(1, {
# "grad_not_none" : grad_not_none,
# "is_not_shared" : is_not_shared,
# "is_not_tp_duplicate" : is_not_tp_duplicate,
# })
# <<<
# pax(1, {
# "grads" : grads,
# "grads_for_norm" : grads_for_norm,
# })
# Norm parameters.
max_norm = float(max_norm)
@@ -88,6 +105,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
# >>>
# pax(1, {
# # "fn" : amp_C.multi_tensor_l2norm,
# "dummy_overflow_buf" : tp(dummy_overflow_buf),
# "grads_for_norm" : grads_for_norm,
# })
# <<<
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
...
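For orientation, the clip_grad_norm_fp32 hunks above instrument the usual Megatron-LM pattern: collect every gradient, but let only gradients whose parameter is neither shared nor a tensor-model-parallel duplicate contribute to the norm, then compute a fused L2 norm and scale everything by the clip coefficient. The sketch below is an illustration of that pattern only, not this file's implementation: it swaps apex's fused multi-tensor kernels for plain PyTorch, omits the model-parallel all-reduce of the norm, and takes the two predicate helpers as arguments.

# Sketch only: plain-PyTorch stand-in for the filtering-and-norm logic of
# clip_grad_norm_fp32 (the real code uses apex's multi_tensor_l2norm /
# multi_tensor_scale and all-reduces the norm across the model-parallel group).
import torch

def sketch_clip_grad_norm_fp32(parameters, max_norm,
                               param_is_not_shared,
                               param_is_not_tp_duplicate):
    grads, grads_for_norm = [], []
    for param in parameters:
        if param.grad is None:
            continue
        grad = param.grad.detach()
        grads.append(grad)
        # Count each weight exactly once: skip shared params and
        # tensor-model-parallel duplicates.
        if param_is_not_shared(param) and param_is_not_tp_duplicate(param):
            grads_for_norm.append(grad)
    if not grads_for_norm:
        return torch.tensor(0.0)
    total_norm = torch.norm(
        torch.stack([torch.norm(g, 2.0) for g in grads_for_norm]), 2.0)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm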
@@ -33,6 +33,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
from lutil import pax, tp
# <<<
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer."""
@@ -97,6 +98,13 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
# >>>
# pax(0, {
# "clip_grad" : clip_grad,
# "params": [ (p.tensor_model_parallel, tp(p)) for p in params ],
# "grads" : [ p.grad for p in params ],
# })
# <<<
return clip_grad_norm_fp32(params, clip_grad)
@@ -179,7 +187,6 @@ class MegatronOptimizer(ABC):
param_groups = property(_get_param_groups, _set_param_groups)
class BaseFloat16Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
@@ -226,6 +233,92 @@ class BaseFloat16Optimizer(MegatronOptimizer):
return self.grad_scaler.scale
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
# pax(1, {"main_grads": main_grads})
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
# raise Exception("hi.")
return found_inf_flag
@torch.no_grad()
def step(self):
timers = get_timers()
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
# >>>
# from lutil import pax, tp
# pax(0, {
# "optimizer / state" :
# { hash(k):tp(v) for k,v in self.optimizer.state.items() },
# "optimizer / state / len" : len(self.optimizer.state),
# "optimizer / state / 0" : list(self.optimizer.state.values())[0],
# })
# <<<
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
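Read together with the later hunks, the block above moves step() and the inf/nan check into the shared BaseFloat16Optimizer, leaving each subclass to answer only one question: which fp32 gradient tensors should be unscaled. A stripped-down sketch of that split, with class, method, and attribute names taken from the diff but all timers and bookkeeping omitted, is shown below; it is an illustration, not the actual classes.

# Sketch only: division of labor after this refactor.
import torch

class _BaseFloat16Sketch:
    def _collect_main_grad_data_for_unscaling(self):
        # Each subclass returns the list of fp32 grad tensors to unscale.
        raise NotImplementedError

    def _unscale_main_grads_and_check_for_nan(self):
        main_grads = self._collect_main_grad_data_for_unscaling()
        self.found_inf.fill_(0.0)
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # (all-reduce of found_inf across the model-parallel group omitted)
        return self.found_inf.item() > 0

class _Float16WithFloat16ParamsSketch(_BaseFloat16Sketch):
    def _collect_main_grad_data_for_unscaling(self):
        # fp32 copies of fp16/bf16 params plus the native fp32 params.
        return [p.grad.data
                for group in self.fp32_from_float16_groups + self.fp32_from_fp32_groups
                for p in group if p.grad is not None]

class _Float16DistributedSketch(_BaseFloat16Sketch):
    def _collect_main_grad_data_for_unscaling(self):
        # One contiguous fp32 main-param shard per optimizer group.
        return [p.grad.data for p in self.main_param_shards]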
# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
"""Float16 optimizer for fp16 and bf16 data types.
@@ -254,12 +347,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-bf16, grad_scaler):
+bf16, grad_scaler, models):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-bf16, grad_scaler)
+bf16, grad_scaler, models)
# ======================
# main parameter stuff
@@ -295,42 +388,16 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
# >>>
def debug():
from lutil import pax, tp
pax(0, {
"optimizer" : optimizer,
# "optimizer / state" : optimizer.state,
"optimizer / pg / 0" : optimizer.param_groups[0]["params"],
"optimizer / pg / 1" : optimizer.param_groups[1]["params"],
"param" : tp(param),
"param / hash" : hash(param),
"main_param" : tp(main_param),
"main_param / hash" : hash(main_param),
})
# <<<
# >>>
# debug()
# from lutil import pax, tp
# pax(0, {
# "param" : tp(param),
# "main_param" : tp(main_param),
# })
# <<<
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
# >>>
# debug()
# <<<
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
# >>>
from lutil import pax
pax(0, {"param": param}) pax(0, {"param": param})
# <<< # <<<
fp32_params_this_group.append(param) fp32_params_this_group.append(param)
...@@ -352,6 +419,16 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer): ...@@ -352,6 +419,16 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# recast preexisting per-param state tensors # recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict()) self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# params = self.get_parameters()
# pax(0, {
# # "params / 0" : params[0],
# "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
# "grads" : [ (param_is_not_tensor_parallel_duplicate(p.grad), tp(p.grad)) for p in params ],
# })
# <<<
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
@@ -458,6 +535,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# <<<
timers('backward-embedding-all-reduce').stop()
def gather_params(self):
raise Exception("hi.")
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups,
@@ -489,31 +570,30 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
if not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
-def _unscale_main_grads_and_check_for_nan(self):
+def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params fromm float16 ones.
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# pax(1, {"main_grads": main_grads})
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
-# Reset found inf.
-self.found_inf.fill_(0.0)
-# Unscale and set found inf/nan
-torch._amp_foreach_non_finite_check_and_unscale_(
-main_grads, self.found_inf, self.grad_scaler.inv_scale)
-# Update across all model parallel instances.
-torch.distributed.all_reduce(self.found_inf,
-op=torch.distributed.ReduceOp.MAX,
-group=mpu.get_model_parallel_group())
-# Check for nan.
-found_inf_flag = (self.found_inf.item() > 0)
-return found_inf_flag
+# >>>
+# from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+# pax(1, {"main_grads": [ (param_is_not_tensor_parallel_duplicate(t), tp(t)) for t in main_grads ]})
+# <<<
+return main_grads
def _get_model_and_main_params_data_float16(self):
@@ -545,66 +625,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
timers = get_timers()
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
# >>>
# from lutil import pax, tp
# pax(0, {
# "optimizer / state" :
# { hash(k):tp(v) for k,v in self.optimizer.state.items() },
# "optimizer / state / len" : len(self.optimizer.state),
# "optimizer / state / 0" : list(self.optimizer.state.values())[0],
# })
# <<<
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
@@ -657,10 +677,6 @@ from megatron import get_args
# from megatron.model import Float16Module
# from megatron.utils import unwrap_model
# >>>
from lutil import pax, tp
# <<<
# class ShardIndex:
class Shard:
def __init__(self, start, end):
@@ -1021,28 +1037,35 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
return shard_map
-# @classmethod
+@classmethod
# def get_param_size_map(cls, model_gbuf_shards):
# def get_param_model_gbuf_map(cls, model_gbuf_shards):
def get_param_gbuf_map(cls, model_gbuf_shards):
# param_size_map = {}
-# for model_gbuf_shard_map in model_gbuf_shards:
-# for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-# for param, param_shard_map in gbuf_shard_map["param_map"].items():
-# assert param not in param_size_map
-# param_size_map[param] = param_shard_map["local"].size
-# # pax(0, {
-# # "dtype" : dtype,
-# # "gbuf_shard_map" : gbuf_shard_map,
-# # "param" : tp(param),
-# # "param_shard_map" : param_shard_map,
-# # })
+param_gbuf_map = {}
+for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
+for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
+for param, param_shard_map in gbuf_shard_map["param_map"].items():
+# assert param not in param_size_map
+# param_size_map[param] = param_shard_map["local"].size
+param_gbuf_map[param] = (model_index, dtype)
+# pax(0, {
+# "dtype" : dtype,
+# "gbuf_shard_map" : gbuf_shard_map,
+# "param" : tp(param),
# "param_shard_map" : param_shard_map,
# })
# pax(0, {
# "model_gbuf_shards" : model_gbuf_shards,
-# "param_size_map" : [ (str(p.shape), s) for p, s in param_size_map.items() ],
-# })
+# # "param_size_map" :
+# # [ (str(p.shape), s) for p, s in param_size_map.items() ],
# "param_gbuf_map" : param_gbuf_map,
# })
# return param_size_map
return param_gbuf_map
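The map built by get_param_gbuf_map is keyed by parameter and is consumed further down in this diff (in _copy_model_grads_to_main_grads) to find which model's grad buffer, and which dtype buffer, a given parameter lives in. A toy illustration with hypothetical parameters and a simplified shard map follows; the names mirror the diff but the data is made up.

# Toy illustration (hypothetical data) of param -> (model_index, dtype).
import torch

param_a = torch.nn.Parameter(torch.zeros(4))
param_b = torch.nn.Parameter(torch.zeros(8))
model_gbuf_shards = [
    {torch.float16: {"param_map": {param_a: {}, param_b: {}}}},  # model 0
]

param_gbuf_map = {}
for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
    for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
        for param in gbuf_shard_map["param_map"]:
            param_gbuf_map[param] = (model_index, dtype)

model_index, dtype = param_gbuf_map[param_b]   # -> (0, torch.float16)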
@classmethod
def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
@@ -1097,7 +1120,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
assert args.use_contiguous_buffers_in_local_ddp # already checked in args
# <<<
-# pax(0, {"models": models})
+# pax(1, {"models": models})
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
@@ -1108,6 +1131,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
self.model_gbuf_shards = []
for model_index, model in enumerate(self.models):
self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
# Optimizer shards.
self.opt_group_shards = self.get_optimizer_group_shards(
@@ -1127,18 +1151,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
self.main_param_shards = []
for group_index, group_shard in enumerate(self.opt_group_shards):
# pax(0, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# })
group_size = group_shard["size"]
-# for dtype in model_main_dtypes ........
+# ** todo: for dtype in model_main_dtypes ........ **
# Allocate shard.
main_param = allocate_shard(group_size, torch.float)
main_param.grad = allocate_shard(group_size, torch.float)
self.main_param_shards.append(main_param)
mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
# Update optimizer group.
self.optimizer.param_groups[group_index]["params"] = [ main_param ]
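The added mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1) call is the line that most directly connects this hunk to the commit message: the clip_grad_norm_fp32 hunks above only count a gradient toward the norm when its parameter passes param_is_not_tensor_parallel_duplicate, and a freshly allocated main-param shard carries no tensor-parallel attributes. For reference, the two upstream Megatron-LM helpers behave roughly as sketched below; this is recalled from the upstream source, simplified, and not part of this commit, and the real predicate reads the tensor-parallel rank itself rather than taking tp_rank as an argument.

# Rough sketch of the upstream helpers referenced here (not part of this diff).
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Tag a tensor so downstream utilities know how (or whether) it is
    # partitioned across tensor-model-parallel ranks.
    setattr(tensor, 'tensor_model_parallel', is_parallel)
    setattr(tensor, 'partition_dim', dim)
    setattr(tensor, 'partition_stride', stride)

def param_is_not_tensor_parallel_duplicate(param, tp_rank):
    # A grad counts toward the global norm if its param is genuinely
    # partitioned, or (for replicated params) only on tensor-parallel rank 0.
    return getattr(param, 'tensor_model_parallel', False) or tp_rank == 0

Under that predicate, marking each shard with (is_parallel=True, dim=0, stride=1) makes every rank's main-param shard contribute to the norm instead of being dropped on non-zero tensor-parallel ranks.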
@@ -1184,15 +1206,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# # })
def zero_grad(self, set_to_none=True):
-params = []
+model_params = []
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
-params.extend(param_map.keys())
+model_params.extend(param_map.keys())
-for main_group in self.optimizer.param_groups:
-params.extend(main_group["params"])
+# main_params = []
+# for main_group in self.optimizer.param_groups:
# main_params.extend(main_group["params"])
-# _zero_grad_group_helper(params, set_to_none)
-_zero_grad_group_helper(params, set_to_none = False)
+_zero_grad_group_helper(model_params, set_to_none)
+# _zero_grad_group_helper(params, set_to_none = False)
# pax(0, {"params": params})
@@ -1349,14 +1372,62 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# raise Exception("hi.")
# <<<
-def step(self):
-raise Exception("step.")
-def gather_params(self):
-raise Exception("gather params.")
+def gather_params(self):
+raise Exception("gather params.")
+# def step(self):
+# raise Exception("step.")
def _collect_main_grad_data_for_unscaling(self):
# pax(1, {
# "main_param_shards" : self.main_param_shards,
# "grads" : [ p.grad.data for p in self.main_param_shards ],
# })
return [ p.grad.data for p in self.main_param_shards ]
def _copy_model_grads_to_main_grads(self):
for group_index, group_shard in enumerate(self.opt_group_shards):
for param, main_shard in group_shard["param_map"].items():
model_index, gbuf_dtype = self.param_gbuf_map[param]
model_shard = self.model_gbuf_shards \
[model_index][gbuf_dtype]["param_map"][param]["world"]
assert main_shard.size == model_shard.size
model_grad_tensor = \
self.models[model_index]._grad_buffers[gbuf_dtype].data
main_grad_tensor = \
self.main_param_shards[group_index].grad
model_grad_view = \
model_grad_tensor[model_shard.start:model_shard.end]
main_grad_view = \
main_grad_tensor[main_shard.start:main_shard.end]
main_grad_view.detach().copy_(model_grad_view)
# pax(0, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# "param" : tp(param),
# "model_index" : model_index,
# "gbuf_dtype" : str(gbuf_dtype),
# "model_grad_tensor" : tp(model_grad_tensor),
# "main_grad_tensor" : tp(main_grad_tensor),
# "model_grad_view" : tp(model_grad_view),
# "main_grad_view" : tp(main_grad_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# })
# <<<
...