OpenDAS / Megatron-LM

Commit c88bc979
Authored Mar 09, 2022 by Lawrence McAfee

updated FP32Optimizer for latest changes.

Parent: 7ac342b7
Showing 2 changed files with 19 additions and 25 deletions (+19 -25):

megatron/optimizer/__init__.py   +5  -4
megatron/optimizer/optimizer.py  +14 -21
megatron/optimizer/__init__.py
@@ -156,7 +156,8 @@ def get_megatron_optimizer(model,
     # else Float32Optimizer
     # return opt_ty(optimizer, args.clip_grad,
     # <<<
-    return Float32Optimizer(optimizer, args.clip_grad,
-                            args.log_num_zeros_in_grad,
-                            params_have_main_grad,
-                            args.use_contiguous_buffers_in_local_ddp)
+    return FP32Optimizer(optimizer, args.clip_grad,
+                         args.log_num_zeros_in_grad,
+                         params_have_main_grad,
+                         args.use_contiguous_buffers_in_local_ddp,
+                         model)
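For reference, a minimal, self-contained sketch of the constructor shape this change introduces. The Sketch* classes, the toy Linear model, and the keyword values below are illustrative stand-ins, not Megatron-LM code; they only show the new trailing models argument being threaded from the call site into the optimizer wrapper.

    # Minimal sketch (not the Megatron-LM implementation): the base wrapper keeps
    # the wrapped model(s) so the optimizer can later reach their grad buffers.
    import torch


    class SketchMegatronOptimizer:
        """Stand-in for MegatronOptimizer."""

        def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                     params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                     models):
            self.optimizer = optimizer
            self.clip_grad = clip_grad
            self.log_num_zeros_in_grad = log_num_zeros_in_grad
            self.params_have_main_grad = params_have_main_grad
            self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
            # Retained so step()/grad handling can reach the models' grad buffers.
            self.models = models


    class SketchFP32Optimizer(SketchMegatronOptimizer):
        """Stand-in for FP32Optimizer; inherits the extended signature unchanged."""

        def step(self):
            self.optimizer.step()


    # Hypothetical call mirroring the new return statement above.
    model = [torch.nn.Linear(4, 4)]
    base_optimizer = torch.optim.Adam(model[0].parameters(), lr=1e-3)
    opt = SketchFP32Optimizer(base_optimizer,
                              clip_grad=1.0,
                              log_num_zeros_in_grad=False,
                              params_have_main_grad=False,
                              use_contiguous_buffers_in_local_ddp=False,
                              models=model)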
megatron/optimizer/optimizer.py
@@ -75,7 +75,8 @@ class MegatronOptimizer(ABC):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_local_ddp):
+                 use_contiguous_buffers_in_local_ddp,
+                 models):

         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
@@ -86,6 +87,10 @@ class MegatronOptimizer(ABC):
         self.params_have_main_grad = params_have_main_grad
         self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

+        # 'models' are retained for access to the contiguous grad buffers.
+        # (see distributed optimizer)
+        self.models = models
+
         if self.use_contiguous_buffers_in_local_ddp:
             assert self.params_have_main_grad, \
                 "use of contiguous buffer requires that params have main grad"
@@ -260,11 +265,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad,
-            use_contiguous_buffers_in_local_ddp)
+            params_have_main_grad,
+            use_contiguous_buffers_in_local_ddp,
+            models)

-        # >>>
-        self.models = models
-        # <<<
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
         # None grad scaler is only supported for bf16.
@@ -382,8 +385,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     @torch.no_grad()
     def step(self, args, timers, ITERATION):
         # timers = get_timers()
         # >>>
-        # self.debug_model(ITERATION, "before copy grad.", 0)
-        # self.debug_main(ITERATION, "before copy grad.", 0)
@@ -608,16 +609,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                 if not self.use_contiguous_buffers_in_local_ddp:
                     model_param.main_grad = None

-        # >>>
-        # if ITERATION == DEBUG_ITERATION:
-        #     pax(0, {
-        #         "** branch **" : "** main. **",
-        #         "ITERATION" : ITERATION,
-        #         "model grads" :
-        #         [ p.main_grad for m in self.models for p in m.parameters() ],
-        #     })
-        # <<<
-
     def _copy_main_params_to_model_params(self, ITERATION):
         # Only needed for the float16 params.
@@ -679,11 +670,13 @@ class FP32Optimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_local_ddp):
+                 use_contiguous_buffers_in_local_ddp,
+                 models):

         super(FP32Optimizer, self).__init__(
             optimizer, clip_grad,
             log_num_zeros_in_grad,
-            params_have_main_grad,
-            use_contiguous_buffers_in_local_ddp)
+            params_have_main_grad,
+            use_contiguous_buffers_in_local_ddp,
+            models)

         self._scale = torch.cuda.FloatTensor([1.0])
@@ -700,7 +693,7 @@ class FP32Optimizer(MegatronOptimizer):
     @torch.no_grad()
-    def step(self, args, timers):
+    def step(self, args, timers, ITERATION):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""
@@ -719,7 +712,7 @@ class FP32Optimizer(MegatronOptimizer):
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
+            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)

         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \
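The new base-class comment states that the models are kept for access to the contiguous grad buffers (see distributed optimizer). As a rough illustration only: the wrapper class, its grad_buffer attribute, and the zero_grad_buffers helper below are hypothetical names, not Megatron-LM's actual DDP API; they merely show the kind of access that holding self.models makes possible.

    # Hedged sketch: what an optimizer holding 'self.models' could do with a
    # contiguous gradient buffer. 'grad_buffer' is an assumed attribute name.
    import torch


    class TinyDDPWrapper(torch.nn.Module):
        """Toy DDP-style wrapper that keeps all grads in one flat buffer."""

        def __init__(self, module):
            super().__init__()
            self.module = module
            numel = sum(p.numel() for p in module.parameters())
            # One contiguous buffer covering every parameter's gradient.
            self.grad_buffer = torch.zeros(numel)


    def zero_grad_buffers(models):
        """Reset the contiguous buffers of every wrapped model."""
        for m in models:
            m.grad_buffer.zero_()


    models = [TinyDDPWrapper(torch.nn.Linear(8, 8))]
    zero_grad_buffers(models)  # e.g. called between training iterations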