OpenDAS / Megatron-LM

Commit 22fa9bac
authored Jan 04, 2021 by mohammad

    further refactor, tested, and changed master to main

Parent: 160ba680

Showing 2 changed files with 74 additions and 73 deletions:

    megatron/optimizer/optimizer.py    +71 -70
    megatron/training.py               +3  -3
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -45,6 +45,18 @@ def _zero_grad_group_helper(group, set_to_none):
                 param.grad.zero_()


+def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
+    """Use multi-tensor-applier to copy values from one list to another."""
+    if overflow_buf:
+        overflow_buf.fill_(0)
+    else:
+        overflow_buf = torch.cuda.IntTensor([0])
+    # Scaling with factor `1.0` is equivalent to copy.
+    multi_tensor_applier(amp_C.multi_tensor_scale,
+                         overflow_buf,
+                         [this, that],
+                         1.0)
+
+
 class MegatronOptimizer(ABC):
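The helper added above calls apex's `multi_tensor_applier` with the `amp_C.multi_tensor_scale` kernel and a scale factor of 1.0, which turns the fused scale into a batched copy across two matching lists of tensors. As an illustration only (not part of this commit), an unfused equivalent in plain PyTorch, assuming `this` and `that` are same-shaped tensor lists, could look like:

import torch

def _copy_this_to_that_unfused(this, that):
    # Hypothetical fallback with the same semantics as
    # _multi_tensor_copy_this_to_that: copy each tensor in `this` into the
    # matching tensor in `that`, casting dtypes as needed.
    for src, dst in zip(this, that):
        dst.copy_(src)

# Example: copy fp16 gradients into fp32 buffers.
fp16_grads = [torch.randn(4).half(), torch.randn(2, 2).half()]
fp32_bufs = [torch.empty(4), torch.empty(2, 2)]
_copy_this_to_that_unfused(fp16_grads, fp32_bufs)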
@@ -127,7 +139,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

         # ======================
-        # master parameter stuff
+        # main parameter stuff
         # ======================

         # Three groups of parameters:
@@ -151,20 +163,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                     if param.type() == 'torch.cuda.HalfTensor':
                         fp16_params_this_group.append(param)
                         # Create a copy
-                        master_param = param.detach().clone().float()
+                        main_param = param.detach().clone().float()
                         # Store grads
-                        master_param.requires_grad = True
+                        main_param.requires_grad = True
                         # Copy tensor model parallel attributes.
-                        mpu.copy_tensor_model_parallel_attributes(master_param,
+                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                   param)
                         if hasattr(param, 'shared'):
-                            master_param.shared = param.shared
+                            main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = master_param
-                        fp32_from_fp16_params_this_group.append(master_param)
-                        # Reset existing state dict key to the new master param.
+                        param_group['params'][i] = main_param
+                        fp32_from_fp16_params_this_group.append(main_param)
+                        # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
-                            self.optimizer.state[master_param] \
+                            self.optimizer.state[main_param] \
                                 = self.optimizer.state.pop(param)

                     # fp32 params.
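This hunk only renames `master_param` to `main_param`; the pattern itself is unchanged: each fp16 model parameter gets a detached fp32 copy that replaces it in the optimizer's param group, so the optimizer always updates full-precision weights. A minimal standalone sketch of that pattern (hypothetical names, not the class's actual state):

import torch

# An fp16 model parameter, as produced by a half-precision model.
param = torch.nn.Parameter(torch.randn(4, 4).half())

# Detached fp32 "main" copy that the optimizer will actually own and update.
main_param = param.detach().clone().float()
main_param.requires_grad = True

# The optimizer sees only the fp32 copy; the fp16 param stays in the model.
optimizer = torch.optim.SGD([main_param], lr=0.1)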
@@ -200,43 +212,39 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return self.grad_scaler.scale


-    def _copy_model_grads_to_master_grads(self):
+    def _copy_model_grads_to_main_grads(self):
         # This only needs to be done for the fp16 group.
         model_grads = []
-        master_grads = []
-        for model_group, master_group in zip(self.fp16_groups,
+        main_grads = []
+        for model_group, main_group in zip(self.fp16_groups,
                                              self.fp32_from_fp16_groups):
-            for model_param, master_param in zip(model_group, master_group):
+            for model_param, main_param in zip(model_group, main_group):
                 if model_param.grad is not None:
-                    if master_param.grad is None:
-                        master_param.grad = torch.empty_like(master_param)
+                    if main_param.grad is None:
+                        main_param.grad = torch.empty_like(main_param)
                     model_grads.append(model_param.grad.data)
-                    master_grads.append(master_param.grad.data)
+                    main_grads.append(main_param.grad.data)
         self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [model_grads, master_grads],
-                             1.0)
+        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
+                                        overflow_buf=self._dummy_overflow_buf)


-    def _unscale_master_grads_and_check_for_nan(self):
-        master_grads = []
+    def _unscale_main_grads_and_check_for_nan(self):
+        main_grads = []
         # fp32 params fromm fp16 ones.
-        for master_group in self.fp32_from_fp16_groups:
-            for master_param in master_group:
-                if master_param.grad is not None:
-                    master_grads.append(master_param.grad.data)
+        for main_group in self.fp32_from_fp16_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
         # Append fp32 parameters.
-        for master_group in self.fp32_from_fp32_groups:
-            for master_param in master_group:
-                if master_param.grad is not None:
-                    master_grads.append(master_param.grad.data)
+        for main_group in self.fp32_from_fp32_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
         # Reset found inf.
         self.found_inf.fill_(0.0)
         # Unscale and set found inf/nan
         torch._amp_foreach_non_finite_check_and_unscale_(
-            master_grads, self.found_inf, self.grad_scaler.inv_scale)
+            main_grads, self.found_inf, self.grad_scaler.inv_scale)
         # Update across all model parallel instances.
         torch.distributed.all_reduce(self.found_inf,
                                      op=torch.distributed.ReduceOp.MAX,
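`_unscale_main_grads_and_check_for_nan` gathers every fp32 gradient into one list, hands it to the fused private kernel `torch._amp_foreach_non_finite_check_and_unscale_` (multiply by `1/loss_scale`, flag any inf/NaN), and then all-reduces the flag with MAX across model-parallel ranks so every rank skips the step together. Spelled out with public ops only (an illustration, not the code above):

import torch

def unscale_and_check(grads, inv_scale):
    # Multiply every gradient by 1/loss_scale in place and report whether
    # any value is non-finite (the signal used to skip the optimizer step).
    found_inf = torch.zeros(1)
    for g in grads:
        g.mul_(inv_scale)
        if not torch.isfinite(g).all():
            found_inf.fill_(1.0)
    return found_inf

grads = [torch.tensor([2048.0, float('inf')]), torch.tensor([4.0])]
print(unscale_and_check(grads, inv_scale=1.0 / 1024.0))  # tensor([1.])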
@@ -247,40 +255,33 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag


-    def _get_model_and_master_params_data_fp16(self):
+    def _get_model_and_main_params_data_fp16(self):
         model_data = []
-        master_data = []
-        for model_group, master_group in zip(self.fp16_groups,
+        main_data = []
+        for model_group, main_group in zip(self.fp16_groups,
                                              self.fp32_from_fp16_groups):
-            for model_param, master_param in zip(model_group, master_group):
+            for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
-                master_data.append(master_param.data)
-        return model_data, master_data
+                main_data.append(main_param.data)
+        return model_data, main_data


-    def _copy_master_params_to_model_params(self):
+    def _copy_main_params_to_model_params(self):
         # Only needed for the fp16 params.
-        model_data, master_data = self._get_model_and_master_params_data_fp16()
+        model_data, main_data = self._get_model_and_main_params_data_fp16()
         self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [master_data, model_data],
-                             1.0)
+        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
+                                        overflow_buf=self._dummy_overflow_buf)


-    def _copy_model_params_to_master_params(self):
+    def _copy_model_params_to_main_params(self):
         # Only needed for the fp16 params.
-        model_data, master_data = self._get_model_and_master_params_data_fp16()
+        model_data, main_data = self._get_model_and_main_params_data_fp16()
         self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [model_data, master_data],
-                             1.0)
+        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
+                                        overflow_buf=self._dummy_overflow_buf)


     def reload_model_params(self):
-        self._copy_model_params_to_master_params()
+        self._copy_model_params_to_main_params()


     @torch.no_grad()
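The two copy methods move data in opposite directions: `_copy_main_params_to_model_params` pushes freshly stepped fp32 values back into the fp16 model parameters, while `_copy_model_params_to_main_params` (used by `reload_model_params`) resynchronizes the fp32 copies from the model. Stripped of the fused helper and using hypothetical standalone lists, the idea reduces to:

import torch

model_data = [torch.randn(3).half(), torch.randn(2).half()]  # fp16 model params
main_data = [t.detach().float() for t in model_data]         # fp32 main copies

# main -> model: after optimizer.step(), cast the fp32 results down to fp16.
for main, model in zip(main_data, model_data):
    model.copy_(main)

# model -> main: resync the fp32 copies from the (possibly reloaded) model.
for model, main in zip(model_data, main_data):
    main.copy_(model)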
@@ -289,17 +290,17 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers = get_timers()

         # ==================================================
-        # Copy gradients from model params to master params.
+        # Copy gradients from model params to main params.
         # ==================================================
-        timers('optimizer-copy-to-master-grad').start()
-        self._copy_model_grads_to_master_grads()
-        timers('optimizer-copy-to-master-grad').stop()
+        timers('optimizer-copy-to-main-grad').start()
+        self._copy_model_grads_to_main_grads()
+        timers('optimizer-copy-to-main-grad').stop()

         # ==============================
         # Unscale and check for inf/nan.
         # ==============================
         timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_master_grads_and_check_for_nan()
+        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
         timers('optimizer-unscale-and-check-inf').stop()

         # ==================================
@@ -315,11 +316,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             return False

         # ==========================
-        # Clip the master gradients.
+        # Clip the main gradients.
         # ==========================
-        timers('optimizer-clip-master-grad').start()
+        timers('optimizer-clip-main-grad').start()
         self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-master-grad').stop()
+        timers('optimizer-clip-main-grad').stop()

         # ===================
         # Step the optimizer.
@@ -327,11 +328,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self.optimizer.step()

         # =================================
-        # Update params from master params.
+        # Update params from main params.
         # =================================
-        timers('optimizer-copy-master-to-model-params').start()
-        self._copy_master_params_to_model_params()
-        timers('optimizer-copy-master-to-model-params').stop()
+        timers('optimizer-copy-main-to-model-params').start()
+        self._copy_main_params_to_model_params()
+        timers('optimizer-copy-main-to-model-params').stop()

         # ==================
         # Successful update.
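Taken together, the renamed pieces implement the usual mixed-precision step: copy fp16 gradients onto the fp32 main params, unscale and check for overflow, bail out on overflow, otherwise clip, step, and copy the updated fp32 values back into the fp16 model. A toy single-parameter walk-through of that sequence (hypothetical, CPU-only, not the optimizer class itself):

import torch

loss_scale = 1024.0
model_param = torch.nn.Parameter(torch.randn(4).half())
main_param = model_param.detach().clone().float().requires_grad_(True)
optimizer = torch.optim.SGD([main_param], lr=0.1)

# Pretend a scaled fp16 backward pass produced this gradient.
model_param.grad = (torch.randn(4) * loss_scale).half()

# Copy the model grad to the fp32 main param and unscale it.
main_param.grad = model_param.grad.float() / loss_scale

# Skip the update entirely if the fp16 gradient overflowed.
if torch.isfinite(main_param.grad).all():
    torch.nn.utils.clip_grad_norm_([main_param], max_norm=1.0)
    optimizer.step()
    # Push the updated fp32 values back into the fp16 model parameter.
    model_param.data.copy_(main_param.data)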
@@ -363,7 +364,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         else:
             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

-        # Copy data for the master params.
+        # Copy data for the main params.
         fp32_from_fp16_params_key = 'fp32_from_fp16_params'
         if fp32_from_fp16_params_key not in state_dict:
             fp32_from_fp16_params_key = 'fp32_from_fp16'
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -677,10 +677,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('backward-send-forward-recv')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('optimizer-copy-to-master-grad')
+    add_to_logging('optimizer-copy-to-main-grad')
     add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-master-grad')
-    add_to_logging('optimizer-copy-master-to-model-params')
+    add_to_logging('optimizer-clip-main-grad')
+    add_to_logging('optimizer-copy-main-to-model-params')
     add_to_logging('optimizer')
     add_to_logging('batch-generator')
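On the training.py side, the rename only touches the timer names that `training_log` collects for logging. Roughly, the surrounding collect-then-log pattern looks like the sketch below (a simplified stand-in, not Megatron's actual `Timers` API):

# Hypothetical stand-in for the pattern around these lines: build a list of
# timer names for this iteration, then hand the list to the logger in one go.
timers_to_log = []

def add_to_logging(name):
    timers_to_log.append(name)

add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-copy-main-to-model-params')
print(timers_to_log)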