OpenDAS / Megatron-LM · Commits

Commit 00764415, authored Jan 23, 2021 by mohammad

    added grad norm to logging and tensorboard

Parent: 1aa2e08a

Changes: 3 changed files with 32 additions and 19 deletions (+32 -19)

    megatron/optimizer/optimizer.py    +9   -6
    megatron/training.py               +17  -10
    tasks/finetune_utils.py            +6   -3
megatron/optimizer/optimizer.py
@@ -70,7 +70,7 @@ class MegatronOptimizer(ABC):
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
                 params.append(param)
-        clip_grad_norm_fp32(params, clip_grad)
+        return clip_grad_norm_fp32(params, clip_grad)

     @abstractmethod
     def zero_grad(self, set_to_none=True):

@@ -311,11 +311,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         # If we found inf/nan, skip the update.
         if found_inf_flag:
-            return False
+            return False, None

         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
-        self.clip_grad_norm(self.clip_grad)
+        grad_norm = None
+        if self.clip_grad > 0.0:
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()

         # Step the optimizer.

@@ -327,7 +329,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers('optimizer-copy-main-to-model-params').stop()

         # Successful update.
-        return True
+        return True, grad_norm

     def state_dict(self):

@@ -392,14 +394,15 @@ class FP32Optimizer(MegatronOptimizer):
         Always return successful since there is no overflow."""

         # Clip gradients.
+        grad_norm = None
         if self.clip_grad > 0.0:
-            self.clip_grad_norm(self.clip_grad)
+            grad_norm = self.clip_grad_norm(self.clip_grad)

         # Update parameters.
         self.optimizer.step()

         # No overflow for FP32 optimizer.
-        return True
+        return True, grad_norm

     def reload_model_params(self):
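The hunks above make clip_grad_norm (and hence step) report the gradient norm instead of discarding it. The internals of clip_grad_norm_fp32 are not part of this commit; as a rough analogue, the sketch below shows the general "clip in place and return the global L2 norm" pattern in plain PyTorch. The helper name clip_and_report_grad_norm and the toy Linear layer are illustrative only, not Megatron code.

# Minimal sketch of the "clip and return the norm" pattern this hunk adopts.
# Plain PyTorch, not Megatron's clip_grad_norm_fp32 (whose internals, e.g.
# model-parallel reductions, are not shown in this commit).
import torch


def clip_and_report_grad_norm(parameters, max_norm):
    """Clip gradients in place and return the pre-clip global L2 norm."""
    params = [p for p in parameters if p.grad is not None]
    # Global L2 norm over all gradients.
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), 2.0) for p in params]), 2.0)
    # Scale gradients down if the norm exceeds the threshold.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            p.grad.detach().mul_(clip_coef)
    return total_norm.item()


if __name__ == '__main__':
    layer = torch.nn.Linear(4, 4)
    layer(torch.randn(2, 4)).sum().backward()
    print('grad norm:', clip_and_report_grad_norm(layer.parameters(), 1.0))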
megatron/training.py
@@ -617,7 +617,7 @@ def train_step(forward_step_func, data_iterator,
     # Update parameters.
     timers('optimizer').start()
-    update_successfull = optimizer.step()
+    update_successfull, grad_norm = optimizer.step()
     timers('optimizer').stop()

     # Update learning rate.

@@ -636,12 +636,12 @@ def train_step(forward_step_func, data_iterator,
         for key in losses_reduced[0]:
             losses_reduced_for_key = [x[key] for x in losses_reduced]
             loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
-        return loss_reduced, skipped_iter
-    return {}, skipped_iter
+        return loss_reduced, skipped_iter, grad_norm
+    return {}, skipped_iter, grad_norm


 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, skipped_iter):
+                 loss_scale, report_memory_flag, skipped_iter, grad_norm):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()

@@ -721,6 +721,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         writer.add_scalar('loss-scale', loss_scale, iteration)
         writer.add_scalar('loss-scale vs samples', loss_scale,
                           args.consumed_train_samples)
+        if grad_norm is not None:
+            writer.add_scalar('grad-norm', grad_norm, iteration)
+            writer.add_scalar('grad-norm vs samples', grad_norm,
+                              args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)

@@ -747,6 +751,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
         log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+        if grad_norm is not None:
+            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
         log_string += ' number of skipped iterations: {:3d} |'.format(
             total_loss_dict[skipped_iters_key])
         log_string += ' number of nan iterations: {:3d} |'.format(

@@ -799,11 +805,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
-        loss_dict, skipped_iter = train_step(forward_step_func,
-                                             train_data_iterator,
-                                             model,
-                                             optimizer,
-                                             lr_scheduler)
+        loss_dict, skipped_iter, grad_norm = train_step(forward_step_func,
+                                                        train_data_iterator,
+                                                        model,
+                                                        optimizer,
+                                                        lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \

@@ -814,7 +820,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
-                                          report_memory_flag, skipped_iter)
+                                          report_memory_flag, skipped_iter,
+                                          grad_norm)

         # Autoresume
         if args.adlr_autoresume and \
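The new writer.add_scalar calls record the gradient norm against both the iteration count and args.consumed_train_samples, guarded so nothing is logged when clipping is disabled. The standalone sketch below shows that logging pattern with torch.utils.tensorboard; the 'grad-norm' tags come from the diff, while the log directory, loop, and dummy values are made up for illustration.

# Standalone sketch of the TensorBoard logging added in training_log().
# The 'grad-norm' / 'grad-norm vs samples' tags are taken from the diff;
# the writer setup and the toy loop are illustrative only.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/grad-norm-demo')  # hypothetical log dir
consumed_train_samples = 0

for iteration in range(1, 6):
    consumed_train_samples += 1024          # pretend global batch size
    grad_norm = 1.0 / iteration             # pretend norm from optimizer.step()
    if grad_norm is not None:               # mirrors the guard in the diff
        writer.add_scalar('grad-norm', grad_norm, iteration)
        writer.add_scalar('grad-norm vs samples', grad_norm,
                          consumed_train_samples)

writer.close()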
tasks/finetune_utils.py
@@ -179,8 +179,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, skipped_iter = train_step(forward_step, batch, model,
-                                                   optimizer, lr_scheduler)
+            losses_dict, skipped_iter, grad_norm = train_step(forward_step,
+                                                              batch, model,
+                                                              optimizer,
+                                                              lr_scheduler)
             iteration += 1

             # Logging.

@@ -188,7 +190,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 optimizer.param_groups[0]['lr'],
                 iteration, optimizer.get_loss_scale().item(),
-                report_memory_flag, skipped_iter)
+                report_memory_flag, skipped_iter,
+                grad_norm)

             # Autoresume
             if args.adlr_autoresume and \
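Both call sites (the pretraining loop in megatron/training.py and the finetuning loop here) now unpack three values from train_step and forward grad_norm to training_log. The toy sketch below emulates that changed contract end to end with stub functions; every name and value in it is a simplified stand-in, not Megatron's actual implementation.

# Toy emulation of the changed call contract: optimizer.step() returns
# (update_successful, grad_norm), train_step() returns
# (loss_dict, skipped_iter, grad_norm), and training_log() receives grad_norm.
# All stubs below are for illustration only.

def fake_optimizer_step():
    """Stand-in for optimizer.step(); grad_norm would be None if clipping is off."""
    return True, 2.345


def train_step():
    update_successful, grad_norm = fake_optimizer_step()
    skipped_iter = 0 if update_successful else 1
    loss_dict = {'lm loss': 3.21} if update_successful else {}
    return loss_dict, skipped_iter, grad_norm


def training_log(loss_dict, iteration, skipped_iter, grad_norm):
    log_string = ' iteration {:d} |'.format(iteration)
    for key, value in loss_dict.items():
        log_string += ' {}: {:.6E} |'.format(key, value)
    if grad_norm is not None:   # same guard as in the diff
        log_string += ' grad norm: {:.3f} |'.format(grad_norm)
    log_string += ' skipped: {:d} |'.format(skipped_iter)
    print(log_string)


loss_dict, skipped_iter, grad_norm = train_step()
training_log(loss_dict, iteration=1, skipped_iter=skipped_iter,
             grad_norm=grad_norm)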