OpenDAS / Megatron-LM · Commits

Commit efa3cbcf
Authored Mar 14, 2022 by Lawrence McAfee

    partially cleaned optimizer.py.

Parent: e6120623

Showing 1 changed file with 12 additions and 58 deletions:

    megatron/optimizer/optimizer.py  (+12, -58)
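The commit message is terse, so a short orientation before the diff: the change strips the temporary ITERATION argument that had been threaded through clip_grad_norm, step, gather_model_params, and the grad/param copy helpers, and drops the lutil pax/tp debug imports. As a hedged illustration only (not code from this commit), a caller simply stops passing the extra argument; optimizer, args, timers, and training_step_update below are placeholders for objects defined elsewhere in the training loop:

    def training_step_update(optimizer, args, timers):
        # Hypothetical call site after this commit: no ITERATION is passed through.
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
        if update_successful:
            # For a non-distributed optimizer this is a documented no-op (see the diff).
            optimizer.gather_model_params(args, timers)
        return update_successful, grad_norm, num_zeros_in_grad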
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -31,12 +31,6 @@ from megatron.utils import unwrap_model
 
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 
-# >>>
-from lutil import pax, tp
-
-DEBUG_ITERATION = 1 # 10
-# <<<
-
 
 def _zero_grad_group_helper(group, set_to_none):
     """Zero out the gradient for a group of parameters.
@@ -110,12 +104,11 @@ class MegatronOptimizer(ABC):
         return mpu.get_model_parallel_group()
 
-    def clip_grad_norm(self, clip_grad, ITERATION):
+    def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
         return clip_grad_norm_fp32(
             params, clip_grad,
-            model_parallel_group=self.get_model_parallel_group(),
-            ITERATION=ITERATION)
+            model_parallel_group=self.get_model_parallel_group())
 
     def count_zeros(self):
@@ -187,7 +180,7 @@ class MegatronOptimizer(ABC):
     def step(self, args, timers):
         pass
 
-    def gather_model_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers):
         '''For the case of a non-distributed-optimizer, there is nothing to
         do here.'''
         pass
@@ -239,9 +232,6 @@ class MegatronOptimizer(ABC):
             torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
 
     def allreduce_embedding_grads(self, args):
-        # >>>
-        # return # ** .. TEMPORARY .. **
-        # <<<
         self.allreduce_word_embedding_grads(args)
         self.allreduce_position_embedding_grads(args)
@@ -260,7 +250,6 @@ class MegatronOptimizer(ABC):
         timers('backward-embedding-all-reduce').stop()
 
 
-# class BaseFloat16Optimizer(MegatronOptimizer):
 class MixedPrecisionOptimizer(MegatronOptimizer):
 
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
@@ -275,6 +264,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
 
         # None grad scaler is only supported for bf16.
         if self.grad_scaler is None:
             assert self.bf16, 'fp16 expects a grad scaler.'
@@ -313,7 +303,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Collect main grads.
         main_grads = self._collect_main_grad_data_for_unscaling()
 
-        # pax(1, {"main_grads": main_grads})
 
         # Reset found inf.
         self.found_inf.fill_(0.0)
@@ -330,25 +319,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
 
-        # >>>
-        # if self.grad_scaler.scale <= 131072:
-        #     pax(0, {
-        #         # "grad_scaler" : self.grad_scaler,
-        #         # "found_inf_flag" : found_inf_flag,
-        #         "model_params" : [
-        #             p
-        #             for m in self.models
-        #             for p in m.parameters()
-        #         ],
-        #         "model_grads" : [
-        #             p.main_grad
-        #             for m in self.models
-        #             for p in m.parameters()
-        #         ],
-        #         # "main_grads" : main_grads,
-        #     })
-        # <<<
-
         return found_inf_flag
 
     # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
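The hunk above only removes the commented-out pax debugging around the inf/NaN check; found_inf_flag still feeds self.grad_scaler.update(found_inf_flag) later in step() (see the next hunks). As a reminder of what such an update typically does, here is a generic dynamic-loss-scaling sketch; the class name and constants are illustrative and are not taken from Megatron's grad scaler:

    class SketchDynamicGradScaler:
        """Illustrative only: back off on overflow, grow after a clean run."""

        def __init__(self, initial_scale=2.0**16, backoff_factor=0.5,
                     growth_factor=2.0, growth_interval=1000, min_scale=1.0):
            self.scale = initial_scale
            self.backoff_factor = backoff_factor
            self.growth_factor = growth_factor
            self.growth_interval = growth_interval
            self.min_scale = min_scale
            self._good_steps = 0

        def update(self, found_inf_flag):
            if found_inf_flag:
                # Overflow: shrink the scale and restart the growth counter.
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._good_steps = 0
            else:
                # No overflow: grow the scale after enough consecutive clean steps.
                self._good_steps += 1
                if self._good_steps % self.growth_interval == 0:
                    self.scale *= self.growth_factor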
@@ -409,16 +379,11 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
 
     @torch.no_grad()
-    def step(self, args, timers, ITERATION):
+    def step(self, args, timers):
 
-        # >>>
-        # self.debug_model(ITERATION, "before copy grad.", 0)
-        # self.debug_main(ITERATION, "before copy grad.", 0)
-        # <<<
-
         # Copy gradients from model params to main params.
         timers('optimizer-copy-to-main-grad').start()
-        self._copy_model_grads_to_main_grads(ITERATION)
+        self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()
 
         # Do unscale, check for inf, and update grad scaler only for
@@ -430,10 +395,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             found_inf_flag = self._unscale_main_grads_and_check_for_nan()
             timers('optimizer-unscale-and-check-inf').stop()
 
-            # >>>
-            # <<<
-
             # We are done with scaling gradients
             # so we can update the loss scale.
             self.grad_scaler.update(found_inf_flag)
@@ -446,7 +407,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         timers('optimizer-clip-main-grad').start()
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
 
         # count the zeros in the grads
@@ -458,20 +419,13 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Update params from main params.
         timers('optimizer-copy-main-to-model-params').start()
-        self._copy_main_params_to_model_params(ITERATION)
+        self._copy_main_params_to_model_params()
         timers('optimizer-copy-main-to-model-params').stop()
 
-        # >>>
-        # self.debug_model(ITERATION, "after copy param.", 0)
-        # self.debug_main(ITERATION, "after copy param.", 0)
-        # <<<
-
         # Successful update.
         return True, grad_norm, num_zeros_in_grad
 
 
-# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
-# class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
     """Float16 optimizer for fp16 and bf16 data types.
@@ -613,7 +567,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         return model_data, main_data
 
-    def _copy_model_grads_to_main_grads(self, ITERATION):
+    def _copy_model_grads_to_main_grads(self):
         # This only needs to be done for the float16 group.
         for model_group, main_group in zip(self.float16_groups,
                                            self.fp32_from_float16_groups):
@@ -645,7 +599,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                 model_param.main_grad = None
 
-    def _copy_main_params_to_model_params(self, ITERATION):
+    def _copy_main_params_to_model_params(self):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
@@ -728,7 +682,7 @@ class FP32Optimizer(MegatronOptimizer):
     @torch.no_grad()
-    def step(self, args, timers, ITERATION):
+    def step(self, args, timers):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""
@@ -747,7 +701,7 @@ class FP32Optimizer(MegatronOptimizer):
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
 
         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \
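Taken together, the hunks above leave step() with the usual mixed-precision update sequence: copy low-precision model grads into fp32 main grads, unscale and check for overflow, update the loss scale, clip, step the wrapped fp32 optimizer, and copy the updated params back. Below is a condensed, hypothetical sketch of that sequence over plain parameter lists; sketch_step, model_params, main_params, and loss_scale are placeholders, timers and zero-counting are omitted, and the real class wraps each stage in its own helper methods:

    import torch

    @torch.no_grad()
    def sketch_step(model_params, main_params, optimizer, loss_scale, clip_grad):
        # 1. Copy low-precision model grads into fp32 grads on the main params.
        for model_p, main_p in zip(model_params, main_params):
            main_p.grad = model_p.grad.detach().float()

        # 2. Unscale the fp32 grads and check for inf/NaN.
        found_inf = False
        for main_p in main_params:
            main_p.grad.mul_(1.0 / loss_scale)
            found_inf = found_inf or not bool(torch.isfinite(main_p.grad).all())

        # 3. On overflow, skip the update (the caller would also shrink the scale).
        if found_inf:
            return False, None

        # 4. Clip the fp32 grads, then step the wrapped fp32 optimizer.
        grad_norm = None
        if clip_grad > 0.0:
            grad_norm = torch.nn.utils.clip_grad_norm_(main_params, clip_grad)
        optimizer.step()

        # 5. Copy the updated fp32 params back into the low-precision model params.
        for model_p, main_p in zip(model_params, main_params):
            model_p.data.copy_(main_p.data)

        return True, grad_norm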