Commit f6811e28 authored by Lawrence McAfee

guard 'gather_params()' with 'if update_successful:'

parent c64098ef
@@ -177,12 +177,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def allocate_main_param_shards(cls, opt_group_shards):

        # Allocator method.
+       # >>>
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)
+       # allocate_shard = lambda shard_size, dtype : torch.zeros(
+       #     (shard_size,),
+       #     dtype = dtype,
+       #     device = torch.cuda.current_device(),
+       #     requires_grad = True)
+       # <<<

        # Allocate each group's param/grad shard.
        for group_index, group_shard in enumerate(opt_group_shards):
...
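As a side note, here is a minimal usage sketch of an allocator lambda like the one above, driven once per optimizer group. 'opt_group_shards' and its 'size' entry are hypothetical stand-ins; this hunk does not show the real shard data structure.

import torch

# Hypothetical sketch (not the Megatron-LM implementation): allocate one fp32
# main-param shard per optimizer group, mirroring the allocate_shard lambda
# above. 'opt_group_shards' is assumed to be a list of dicts with a 'size' key.
def allocate_main_param_shards_sketch(opt_group_shards):
    allocate_shard = lambda shard_size, dtype: torch.empty(
        (shard_size,),
        dtype=dtype,
        device=torch.cuda.current_device(),
        requires_grad=True)

    main_param_shards = []
    for group_shard in opt_group_shards:
        # Main params are kept in fp32 for mixed-precision optimization.
        main_param_shards.append(allocate_shard(group_shard['size'], torch.float))
    return main_param_shards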
@@ -330,6 +330,25 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
+       # >>>
+       # if self.grad_scaler.scale <= 131072:
+       #     pax(0, {
+       #         # "grad_scaler" : self.grad_scaler,
+       #         # "found_inf_flag" : found_inf_flag,
+       #         "model_params" : [
+       #             p
+       #             for m in self.models
+       #             for p in m.parameters()
+       #         ],
+       #         "model_grads" : [
+       #             p.main_grad
+       #             for m in self.models
+       #             for p in m.parameters()
+       #         ],
+       #         # "main_grads" : main_grads,
+       #     })
+       # <<<

        return found_inf_flag
    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
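For context, a hedged sketch of the unscale-and-check step that yields 'found_inf_flag'. It leans on PyTorch's private torch._amp_foreach_non_finite_check_and_unscale_ helper (the same one torch.cuda.amp.GradScaler uses); treat the exact tensors and names below as an approximation rather than this class's verified implementation.

import torch

def unscale_and_check_for_nan_sketch(main_grads, loss_scale):
    # 'main_grads': list of fp32 grad tensors on the current GPU (assumed).
    # 'loss_scale': 1-element CUDA float tensor holding the current scale.
    found_inf = torch.cuda.FloatTensor([0.0])
    inv_scale = loss_scale.double().reciprocal().float()
    # Multiplies every grad by 1/scale in place, and sets found_inf to 1.0
    # if any grad contains inf/NaN.
    torch._amp_foreach_non_finite_check_and_unscale_(
        main_grads, found_inf, inv_scale)
    # Check for nan, as in the hunk above.
    return found_inf.item() > 0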
@@ -411,6 +430,10 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
        timers('optimizer-unscale-and-check-inf').stop()
+       # >>>
+       # <<<

        # We are done with scaling gradients
        # so we can update the loss scale.
        self.grad_scaler.update(found_inf_flag)
...
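For context, a hedged sketch of what 'self.grad_scaler.update(found_inf_flag)' does under dynamic loss scaling: back off on overflow, grow again after a run of clean steps. The hyperparameter names and defaults below are illustrative, not necessarily Megatron-LM's.

class DynamicLossScaleSketch:
    # Illustrative dynamic loss scaler; defaults are assumptions.
    def __init__(self, initial_scale=2.0**16, min_scale=1.0,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=1000):
        self.scale = initial_scale
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf_flag):
        if found_inf_flag:
            # Overflow detected: shrink the scale and reset the success count.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            # After growth_interval consecutive clean steps, grow the scale.
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor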
@@ -453,7 +453,8 @@ def train_step(forward_step_func, data_iterator,
    # >>>
    # Gather params.
-   optimizer.gather_model_params(args, timers, ITERATION)
+   if update_successful:
+       optimizer.gather_model_params(args, timers, ITERATION)
    # <<<
    # >>>
...
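The hunk above is the point of the commit: only gather updated parameters back into the model when the optimizer step actually applied an update (e.g. it was not skipped because of a loss-scale overflow). A minimal sketch of the guarded pattern follows; the return signature of optimizer.step() is assumed here, not taken from this diff.

def step_and_gather_sketch(optimizer, args, timers, iteration):
    # Hypothetical wrapper illustrating the guard. optimizer.step() is assumed
    # to report whether the update was applied (False on inf/NaN overflow).
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    if update_successful:
        # A skipped step leaves the model params unchanged, so the (expensive)
        # param all-gather is only worth doing when the update went through.
        optimizer.gather_model_params(args, timers, iteration)
    return update_successful, grad_norm, num_zeros_in_grad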