"router/vscode:/vscode.git/clone" did not exist on "f7ac394935f9ce502d827a1e8b3be2396c44f950"
Commit 772a4a2d authored by Lawrence McAfee

encapsulated 'get_model_parallel_group()'; renamed BaseFloat16Optimizer -> MixedPrecisionOptimizer

parent 6e888151
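For context, this commit factors the choice of reduction group into a single overridable method: the renamed MixedPrecisionOptimizer declares get_model_parallel_group() as abstract, Float16OptimizerWithFloat16Params returns mpu.get_model_parallel_group(), and DistributedOptimizer returns None (no parameter replication across its group). Below is a minimal, self-contained sketch of that pattern, not the Megatron-LM implementation itself; the _MpuStub class and _sync_found_inf helper are stand-ins invented for illustration.

```python
from abc import ABC, abstractmethod

import torch.distributed as dist


class _MpuStub:
    """Stand-in for megatron.mpu in this sketch; the real module returns
    the model-parallel process group."""
    @staticmethod
    def get_model_parallel_group():
        return None  # illustration only


mpu = _MpuStub()


class MixedPrecisionOptimizer(ABC):
    """Renamed from BaseFloat16Optimizer in this commit. Subclasses declare
    which process group the found-inf flag is reduced across."""

    @abstractmethod
    def get_model_parallel_group(self):
        ...

    def _sync_found_inf(self, found_inf):
        # Mirrors the shape of _unscale_main_grads_and_check_for_nan:
        # take the MAX of the local found-inf flags across the group
        # reported by the subclass.
        if dist.is_initialized():
            dist.all_reduce(found_inf, op=dist.ReduceOp.MAX,
                            group=self.get_model_parallel_group())
        return found_inf.item() > 0


class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    def get_model_parallel_group(self):
        return mpu.get_model_parallel_group()


class DistributedOptimizer(MixedPrecisionOptimizer):
    def get_model_parallel_group(self):
        # No parameter replication across this optimizer's group, so None
        # is returned; all_reduce treats None as the default (world) group.
        return None
```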
@@ -321,45 +321,51 @@ class DistributedOptimizer(MegatronOptimizer):
         # Initialize main params.
         self._copy_model_params_to_main_params()
 
-    @staticmethod
-    def has_nan_debug(tensors):
-        if isinstance(tensors, torch.Tensor):
-            tensors = [ tensors ]
-        assert isinstance(tensors, list)
-        has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
-        has_nan = any(has_nans)
-        return has_nan
-    def get_local_model_param_views(self):
-        '''** FOR DEBUGGING. **'''
-        model_param_views = []
-        for group_index, opt_group_shard in enumerate(self.opt_group_shards):
-            for param, opt_shard in opt_group_shard["param_map"].items():
-                model_index, dtype = self.param_gbuf_map[param]
-                gbuf_shard_map = \
-                    self.model_gbuf_shards[model_index][dtype]["param_map"][param]
-                model_param_shard = gbuf_shard_map["param"]
-                model_param_views.append(
-                    param.view(-1)[model_param_shard.start:model_param_shard.end])
-        return model_param_views
-    def get_local_model_grad_views(self):
-        '''** FOR DEBUGGING. **'''
-        model_grad_views = []
-        for group_index, opt_group_shard in enumerate(self.opt_group_shards):
-            for param, opt_shard in opt_group_shard["param_map"].items():
-                model_index, dtype = self.param_gbuf_map[param]
-                gbuf = self.models[model_index]._grad_buffers[dtype].data
-                gbuf_shard_map = \
-                    self.model_gbuf_shards[model_index][dtype]["param_map"][param]
-                gbuf_world_shard = gbuf_shard_map["gbuf_world"]
-                model_grad_views.append(
-                    gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
-        return model_grad_views
-    def get_world_model_params(self):
-        '''** FOR DEBUGGING. **'''
-        return [ p for m in self.models for p in m.parameters() ]
-    def get_world_model_grads(self):
-        '''** FOR DEBUGGING. **'''
-        return [ p.main_grad for p in self.get_world_model_params() ]
+    def get_model_parallel_group(self):
+        # >>>
+        # i.e., no param replication across this group
+        # <<<
+        return None
+
+    # @staticmethod
+    # def has_nan_debug(tensors):
+    #     if isinstance(tensors, torch.Tensor):
+    #         tensors = [ tensors ]
+    #     assert isinstance(tensors, list)
+    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
+    #     has_nan = any(has_nans)
+    #     return has_nan
+    # def get_local_model_param_views(self):
+    #     '''** FOR DEBUGGING. **'''
+    #     model_param_views = []
+    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
+    #         for param, opt_shard in opt_group_shard["param_map"].items():
+    #             model_index, dtype = self.param_gbuf_map[param]
+    #             gbuf_shard_map = \
+    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
+    #             model_param_shard = gbuf_shard_map["param"]
+    #             model_param_views.append(
+    #                 param.view(-1)[model_param_shard.start:model_param_shard.end])
+    #     return model_param_views
+    # def get_local_model_grad_views(self):
+    #     '''** FOR DEBUGGING. **'''
+    #     model_grad_views = []
+    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
+    #         for param, opt_shard in opt_group_shard["param_map"].items():
+    #             model_index, dtype = self.param_gbuf_map[param]
+    #             gbuf = self.models[model_index]._grad_buffers[dtype].data
+    #             gbuf_shard_map = \
+    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
+    #             gbuf_world_shard = gbuf_shard_map["gbuf_world"]
+    #             model_grad_views.append(
+    #                 gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
+    #     return model_grad_views
+    # def get_world_model_params(self):
+    #     '''** FOR DEBUGGING. **'''
+    #     return [ p for m in self.models for p in m.parameters() ]
+    # def get_world_model_grads(self):
+    #     '''** FOR DEBUGGING. **'''
+    #     return [ p.main_grad for p in self.get_world_model_params() ]
 
     def get_main_params(self):
         return [ g["params"][0] for g in self.optimizer.param_groups ]
...
@@ -182,7 +182,8 @@ class MegatronOptimizer(ABC):
     param_groups = property(_get_param_groups, _set_param_groups)
 
 
-class BaseFloat16Optimizer(MegatronOptimizer):
+# class BaseFloat16Optimizer(MegatronOptimizer):
+class MixedPrecisionOptimizer(MegatronOptimizer):
 
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
@@ -222,6 +223,10 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         self._scale_one = torch.cuda.FloatTensor([1.0])
 
+    @abstractmethod
+    def get_model_parallel_group(self):
+        pass
+
     def get_loss_scale(self):
         if self.grad_scaler is None:
             return self._scale_one
@@ -232,7 +237,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()
 
-    def _unscale_main_grads_and_check_for_nan(self):
+    def _unscale_main_grads_and_check_for_nan(self, group):
 
         # Collect main grads.
         main_grads = self._collect_main_grad_data_for_unscaling()
@@ -246,13 +251,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
             main_grads, self.found_inf, self.grad_scaler.inv_scale)
 
         # Update across all model parallel instances.
         # >>>
         # torch.distributed.all_reduce(self.found_inf,
         #                              op=torch.distributed.ReduceOp.MAX,
         #                              group=mpu.get_model_parallel_group())
         # +++
         torch.distributed.all_reduce(self.found_inf,
-                                     op=torch.distributed.ReduceOp.MAX)
+                                     op=torch.distributed.ReduceOp.MAX,
+                                     group=self.get_model_parallel_group())
         # <<<
 
         # Check for nan.
@@ -517,6 +523,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         # <<<
 
+    def get_model_parallel_group(self):
+        return mpu.get_model_parallel_group()
+
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
         float16_groups & fp32_from_fp32_groups. We additionally zero
...
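One detail worth noting about DistributedOptimizer.get_model_parallel_group() returning None: torch.distributed.all_reduce treats group=None as the default (world) process group, so the found-inf flag ends up reduced across all ranks. A small single-process demonstration of that check-and-reduce shape follows (gloo backend); it is a simplified stand-in, not Megatron's unscale path, and the variable names are illustrative.

```python
import os
import torch
import torch.distributed as dist

# Single-rank gloo group so the script runs on CPU without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

main_grads = [torch.tensor([1.0, float("inf")]), torch.tensor([2.0, 3.0])]
found_inf = torch.tensor([0.0])

# Set the local found-inf flag, then take the MAX across the group.
# group=None uses the default process group.
for g in main_grads:
    if not torch.all(torch.isfinite(g)):
        found_inf.fill_(1.0)
dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=None)
print("found_inf:", bool(found_inf.item()))  # True

dist.destroy_process_group()
```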