Commit 772a4a2d authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

encapsulated 'get_model_parallel_group()'; renamed BaseFloat16Optimizer -> MixedPrecisionOptimizer

parent 6e888151
......@@ -321,45 +321,51 @@ class DistributedOptimizer(MegatronOptimizer):
# Initialize main params.
self._copy_model_params_to_main_params()
@staticmethod
def has_nan_debug(tensors):
if isinstance(tensors, torch.Tensor):
tensors = [ tensors ]
assert isinstance(tensors, list)
has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
has_nan = any(has_nans)
return has_nan
def get_local_model_param_views(self):
    """** FOR DEBUGGING. ** Flat 1-D views into each param's local shard."""
    views = []
    for group_shard in self.opt_group_shards:
        for param in group_shard["param_map"]:
            model_index, dtype = self.param_gbuf_map[param]
            shard = (self.model_gbuf_shards[model_index][dtype]
                     ["param_map"][param]["param"])
            flat_param = param.view(-1)
            views.append(flat_param[shard.start:shard.end])
    return views
def get_local_model_grad_views(self):
    """** FOR DEBUGGING. ** Views into each param's grad-buffer world shard."""
    views = []
    for group_shard in self.opt_group_shards:
        for param in group_shard["param_map"]:
            model_index, dtype = self.param_gbuf_map[param]
            gbuf = self.models[model_index]._grad_buffers[dtype].data
            world_shard = (self.model_gbuf_shards[model_index][dtype]
                           ["param_map"][param]["gbuf_world"])
            views.append(gbuf[world_shard.start:world_shard.end])
    return views
def get_world_model_params(self):
    """** FOR DEBUGGING. ** All parameters across all local model chunks."""
    params = []
    for model in self.models:
        params.extend(model.parameters())
    return params
def get_world_model_grads(self):
    """** FOR DEBUGGING. ** The .main_grad of every world model param."""
    grads = []
    for param in self.get_world_model_params():
        grads.append(param.main_grad)
    return grads
def get_model_parallel_group(self):
    """Return the group used for the found-inf reduction.

    Returns None — i.e., no param replication across this group, so the
    caller falls back to the default process group.
    """
    return None
# @staticmethod
# def has_nan_debug(tensors):
# if isinstance(tensors, torch.Tensor):
# tensors = [ tensors ]
# assert isinstance(tensors, list)
# has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
# has_nan = any(has_nans)
# return has_nan
# def get_local_model_param_views(self):
# '''** FOR DEBUGGING. **'''
# model_param_views = []
# for group_index, opt_group_shard in enumerate(self.opt_group_shards):
# for param, opt_shard in opt_group_shard["param_map"].items():
# model_index, dtype = self.param_gbuf_map[param]
# gbuf_shard_map = \
# self.model_gbuf_shards[model_index][dtype]["param_map"][param]
# model_param_shard = gbuf_shard_map["param"]
# model_param_views.append(
# param.view(-1)[model_param_shard.start:model_param_shard.end])
# return model_param_views
# def get_local_model_grad_views(self):
# '''** FOR DEBUGGING. **'''
# model_grad_views = []
# for group_index, opt_group_shard in enumerate(self.opt_group_shards):
# for param, opt_shard in opt_group_shard["param_map"].items():
# model_index, dtype = self.param_gbuf_map[param]
# gbuf = self.models[model_index]._grad_buffers[dtype].data
# gbuf_shard_map = \
# self.model_gbuf_shards[model_index][dtype]["param_map"][param]
# gbuf_world_shard = gbuf_shard_map["gbuf_world"]
# model_grad_views.append(
# gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
# return model_grad_views
# def get_world_model_params(self):
# '''** FOR DEBUGGING. **'''
# return [ p for m in self.models for p in m.parameters() ]
# def get_world_model_grads(self):
# '''** FOR DEBUGGING. **'''
# return [ p.main_grad for p in self.get_world_model_params() ]
def get_main_params(self):
    """Return the first (main) parameter of each optimizer param group."""
    main_params = []
    for group in self.optimizer.param_groups:
        main_params.append(group["params"][0])
    return main_params
......
......@@ -182,7 +182,8 @@ class MegatronOptimizer(ABC):
param_groups = property(_get_param_groups, _set_param_groups)
class BaseFloat16Optimizer(MegatronOptimizer):
# class BaseFloat16Optimizer(MegatronOptimizer):
class MixedPrecisionOptimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
......@@ -222,6 +223,10 @@ class BaseFloat16Optimizer(MegatronOptimizer):
self._scale_one = torch.cuda.FloatTensor([1.0])
@abstractmethod
def get_model_parallel_group(self):
    """Return the process group over which found-inf is all-reduced.

    Fix: the original signature took a spurious `state_dict` argument
    (likely copy-pasted); the call site invokes
    `self.get_model_parallel_group()` with no arguments, and the concrete
    overrides are zero-argument, so the parameter is removed.

    Returns:
        A torch.distributed process group, or None for the default group.
    """
    pass
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
......@@ -232,7 +237,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
def _unscale_main_grads_and_check_for_nan(self, group):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
......@@ -246,13 +251,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
# >>>
if args.use_# >>>
# torch.distributed.all_reduce(self.found_inf,
# op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
# +++
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX)
op=torch.distributed.ReduceOp.MAX,
group=self.get_model_parallel_group())
# <<<
# Check for nan.
......@@ -517,6 +523,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# <<<
def get_model_parallel_group(self):
    """Return the model-parallel process group for found-inf reduction.

    Fix: the original line had an unbalanced closing parenthesis
    (`return mpu.get_model_parallel_group())`), which is a syntax error.
    """
    return mpu.get_model_parallel_group()
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups. We additionally zero
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment