OpenDAS / Megatron-LM / Commits

Commit db88a27b ("addressed Jareds and Deepaks comments")
Authored Jan 05, 2021 by mohammad
Parent: 512337f5
Showing 3 changed files with 14 additions and 25 deletions:
megatron/optimizer/clip_grads.py   +3   -0
megatron/optimizer/optimizer.py    +6  -16
megatron/training.py               +5   -9
megatron/optimizer/clip_grads.py

```diff
@@ -83,6 +83,9 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     else:
         if norm_type == 2.0:
             dummy_overflow_buf = torch.cuda.IntTensor([0])
+            # Use apex's multi-tensor applier for efficiency reasons.
+            # Multi-tensor applier takes a function and a list of list
+            # and performs the operation on that list all in one kernel.
             grad_norm, _ = multi_tensor_applier(
                 amp_C.multi_tensor_l2norm,
                 dummy_overflow_buf,
```
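The three added comment lines explain why apex's fused kernel is used here. For reference, a mathematically equivalent reduction without apex might look like the sketch below; `grads_for_norm` stands in for the list of gradient tensors passed in the call's remaining (truncated) arguments:

```python
import torch

def l2_grad_norm_unfused(grads_for_norm):
    # Same math as amp_C.multi_tensor_l2norm: sqrt(sum_i ||g_i||_2^2),
    # but one kernel launch per tensor instead of a single fused kernel.
    per_tensor_norms = [torch.norm(g.detach(), 2.0) for g in grads_for_norm]
    return torch.norm(torch.stack(per_tensor_norms), 2.0)
```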
megatron/optimizer/optimizer.py

```diff
@@ -78,6 +78,7 @@ class MegatronOptimizer(ABC):

     @abstractmethod
     def get_loss_scale(self):
+        """The output should be a cuda tensor of size 1."""
         pass

     def scale_loss(self, loss):
```
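The added docstring pins down the contract: `get_loss_scale` returns a CUDA tensor of size 1 for every optimizer, fp16 or fp32. That is what lets `scale_loss` stay generic; a sketch of the likely relationship, assuming `scale_loss` is built directly on `get_loss_scale`:

```python
def scale_loss(self, loss):
    # Multiplying by a size-1 cuda tensor keeps the result on-device;
    # for fp32 optimizers the scale is simply 1.0, so this is a no-op there.
    return self.get_loss_scale() * loss
```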
```diff
@@ -90,6 +91,11 @@ class MegatronOptimizer(ABC):

     @abstractmethod
     def reload_model_params(self):
+        """Refreshes any internal state from the current model parameters.
+        Call whenever the parameters are changed outside of the optimizer.
+        For example, when we load a model from a checkpoint without loading
+        the optimizer, the model parameters are updated but for fp16 optimizer
+        with main parameters, the main parameters need to also be updated."""
         pass

     @abstractmethod
```
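The new docstring spells out when `reload_model_params` must be called. For an fp16 optimizer with fp32 main parameters, a minimal sketch of what the override might do (the `float16_groups` and `fp32_from_float16_groups` attribute names are assumptions for illustration, not the repository's actual attributes):

```python
def reload_model_params(self):
    # Re-sync the fp32 main copies from the (externally updated) fp16
    # model weights, e.g. after loading a checkpoint without optimizer state.
    for model_group, main_group in zip(self.float16_groups,
                                       self.fp32_from_float16_groups):
        for model_param, main_param in zip(model_group, main_group):
            main_param.data.copy_(model_param.data)
```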
```diff
@@ -289,54 +295,38 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

         timers = get_timers()

-        # ==================================================
         # Copy gradients from model params to main params.
-        # ==================================================
         timers('optimizer-copy-to-main-grad').start()
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()

-        # ==============================
         # Unscale and check for inf/nan.
-        # ==============================
         timers('optimizer-unscale-and-check-inf').start()
         found_inf_flag = self._unscale_main_grads_and_check_for_nan()
         timers('optimizer-unscale-and-check-inf').stop()

-        # ==================================
         # We are done with scaling gradients
         # so we can update the loss scale.
-        # ==================================
         self.grad_scaler.update(found_inf_flag)

-        # =====================================
         # If we found inf/nan, skip the update.
-        # =====================================
         if found_inf_flag:
             return False

-        # ==========================
         # Clip the main gradients.
-        # ==========================
         timers('optimizer-clip-main-grad').start()
         self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()

-        # ===================
         # Step the optimizer.
-        # ===================
         self.optimizer.step()

-        # =================================
         # Update params from main params.
-        # =================================
         timers('optimizer-copy-main-to-model-params').start()
         self._copy_main_params_to_model_params()
         timers('optimizer-copy-main-to-model-params').stop()

-        # ==================
         # Successful update.
-        # ==================
         return True
```
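Note that `grad_scaler.update(found_inf_flag)` runs before the early return, so the loss scale adapts even on skipped iterations. A minimal sketch of the usual dynamic-scaling policy (backoff on overflow, growth after a run of clean steps; the constants are illustrative, not the repository's exact scaler):

```python
class DynamicGradScaler:
    def __init__(self, initial_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000):
        self.scale = initial_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf):
        if found_inf:
            # Overflow: shrink the scale and restart the growth counter.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            # After enough consecutive clean steps, grow the scale back.
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor
```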
megatron/training.py

```diff
@@ -703,10 +703,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             writer.add_scalar(key, loss_dict[key], iteration)
             writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
-        if args.fp16:
-            writer.add_scalar('loss-scale', loss_scale, iteration)
-            writer.add_scalar('loss-scale vs samples', loss_scale,
-                              args.consumed_train_samples)
+        writer.add_scalar('loss-scale', loss_scale, iteration)
+        writer.add_scalar('loss-scale vs samples', loss_scale,
+                          args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)
```
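Both hunks in this file drop `if args.fp16:` guards around loss-scale reporting. The dual `add_scalar` pattern, the same value keyed once by iteration and once by consumed samples, can be exercised standalone with `torch.utils.tensorboard`; the values below are stand-ins:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/demo')  # hypothetical log dir
for iteration in range(1, 11):
    loss_scale = 2.0 ** 12                    # stand-in value
    consumed_train_samples = iteration * 1024  # stand-in batch accounting
    # Logging against both x-axes lets curves be compared across runs
    # that use different batch sizes.
    writer.add_scalar('loss-scale', loss_scale, iteration)
    writer.add_scalar('loss-scale vs samples', loss_scale,
                      consumed_train_samples)
writer.close()
```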
```diff
@@ -732,8 +731,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             if avg > 0.0:
                 log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
-    if args.fp16:
-        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+    log_string += ' loss scale: {:.1f} |'.format(loss_scale)
     log_string += ' number of skipped iterations: {:3d} |'.format(
         total_loss_dict[skipped_iters_key])
     log_string += ' number of nan iterations: {:3d} |'.format(
```
```diff
@@ -797,9 +795,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        get_num_microbatches()

         # Logging.
-        loss_scale = None
-        if args.fp16:
-            loss_scale = optimizer.get_loss_scale().item()
+        loss_scale = optimizer.get_loss_scale().item()
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
```
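With the conditional fetch gone, `train` assumes every optimizer answers `get_loss_scale`, matching the abstract method's new docstring. A plausible fp32 counterpart returning a constant scale of one (a sketch; the remaining abstract methods are omitted):

```python
import torch

class FP32Optimizer(MegatronOptimizer):  # other abstract methods omitted
    def get_loss_scale(self):
        # No scaling in fp32: a constant size-1 cuda tensor keeps
        # optimizer.get_loss_scale().item() valid in every configuration.
        return torch.cuda.FloatTensor([1.0])
```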