OpenDAS / Megatron-LM · Commit c6a58e41

Commit c6a58e41, authored Dec 24, 2020 by mohammad
Parent: 2348c99a

    working on the new optimizer

Showing 2 changed files with 401 additions and 9 deletions (+401 −9):

    megatron/optimizer/optimizer.py   +378  −0
    megatron/training.py              +23   −9
megatron/optimizer/optimizer.py  (new file, mode 0 → 100644)
from abc import ABC
from abc import abstractmethod

import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import mpu
from megatron import get_args


def get_megatron_optimizer(optimizer):

    args = get_args()

    grad_scaler = DynamicGradScaler(initial_scale=2**32,
                                    min_scale=args.min_scale,
                                    growth_factor=2.0,
                                    backoff_factor=0.5,
                                    growth_interval=args.loss_scale_window,
                                    hysteresis=args.hysteresis)

    megatron_optimizer = FP16OptimizerWithFP16Params(optimizer, grad_scaler,
                                                     args.clip_grad)

    return megatron_optimizer
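
This factory is what get_optimizer() in megatron/training.py switches over to in this commit (see the diff further down). A minimal sketch of the intended call pattern, assuming Megatron's global args have already been parsed and that optimizer is the base Adam instance built earlier in get_optimizer():

    # --- illustration only, not part of this commit ---
    if args.fp16:
        # Wrap the base Adam optimizer.
        optimizer = get_megatron_optimizer(optimizer)
        # The wrapper scales the loss before the backward pass ...
        loss = optimizer.scale_loss(loss)
    loss.backward()
    # ... and its step() returns False when an inf/nan was found, in which case
    # the parameter update (and the learning-rate step) should be skipped.
    update_successfull = optimizer.step()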
class MegatronGradScaler(ABC):

    def __init__(self, initial_scale):
        """Initialize scale value with the input initial scale."""
        assert initial_scale > 0.0
        self._scale = torch.cuda.FloatTensor([initial_scale])

    @property
    def scale(self):
        return self._scale

    @property
    def inv_scale(self):
        return self._scale.double().reciprocal().float()

    @abstractmethod
    def update(self, found_inf):
        pass

    '''
    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
    '''
class ConstantGradScaler(MegatronGradScaler):
    pass


class DynamicGradScaler(MegatronGradScaler):

    def __init__(self, initial_scale, min_scale, growth_factor,
                 backoff_factor, growth_interval, hysteresis):
        """Grad scaler with dynamic scale that gets adjusted
        during training."""
        super(DynamicGradScaler, self).__init__(initial_scale)

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
        self.min_scale = torch.cuda.FloatTensor([min_scale])
        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
        # Interval over which, if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
        self.growth_interval = growth_interval
        # Number of inf/nans we should see before scaling down
        # the grad scale by the backoff factor.
        assert hysteresis > 0
        self.hysteresis = hysteresis

        # Trackers.
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

    def update(self, found_inf):
        # If we have an inf/nan, the growth tracker is set to 0
        # and the hysteresis tracker is reduced by 1.
        if found_inf:
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            # Now if we are out of the hysteresis count, scale down the loss.
            if self._hysteresis_tracker <= 0:
                self._scale = torch.max(self._scale * self.backoff_factor,
                                        self.min_scale)
        else:
            # If there is no nan/inf, increment the growth tracker.
            self._growth_tracker += 1
            # If we have had enough consecutive intervals with no nan/inf:
            if self._growth_tracker == self.growth_interval:
                # Reset the growth and hysteresis trackers,
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                self._scale = self._scale * self.growth_factor
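
To make the update policy concrete, here is a CPU-only rendition of the same logic using plain Python floats (the hyperparameter values below are made up for the example, not Megatron defaults): growth_interval clean steps grow the scale by growth_factor, while hysteresis consecutive inf/nan steps shrink it by backoff_factor, never below min_scale.

    # --- illustration only, not part of this commit ---
    min_scale, growth_factor, backoff_factor = 1.0, 2.0, 0.5
    growth_interval, hysteresis = 3, 2

    scale, growth_tracker, hysteresis_tracker = 2.0 ** 16, 0, hysteresis
    for found_inf in [False, False, False, True, True]:
        if found_inf:
            growth_tracker = 0
            hysteresis_tracker -= 1
            if hysteresis_tracker <= 0:
                scale = max(scale * backoff_factor, min_scale)
        else:
            growth_tracker += 1
            if growth_tracker == growth_interval:
                growth_tracker = 0
                hysteresis_tracker = hysteresis
                scale *= growth_factor
        print(found_inf, scale)
    # Prints 65536.0, 65536.0, 131072.0 (grown after 3 clean steps),
    # 131072.0 (the first inf only decrements the hysteresis counter),
    # and finally 65536.0 (backed off after the second consecutive inf).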
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
class MegatronOptimizer(ABC):

    def __init__(self, optimizer):
        """Input optimizer is the base optimizer, for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        pass

    '''
    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
    '''

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
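
The two property blocks simply forward state and param_groups to the wrapped optimizer, which is why code such as the optimizer.param_groups[0]['lr'] lookup in training.py keeps working unchanged. A toy, CPU-only subclass (not part of this commit) that exercises the pass-through:

    # --- illustration only, not part of this commit ---
    import torch

    class _PassthroughOptimizer(MegatronOptimizer):
        """Minimal concrete subclass, just enough to be instantiable."""
        def zero_grad(self, set_to_none=True):
            self.optimizer.zero_grad()
        def get_loss_scale(self):
            return 1.0
        def step(self):
            self.optimizer.step()
            return True

    weight = torch.nn.Parameter(torch.zeros(4))
    wrapped = _PassthroughOptimizer(torch.optim.SGD([weight], lr=0.1))
    wrapped.param_groups[0]['lr'] = 0.01      # writes through to the inner SGD
    assert wrapped.optimizer.param_groups[0]['lr'] == 0.01
    assert wrapped.state is wrapped.optimizer.state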
class FP16OptimizerWithFP16Params(MegatronOptimizer):

    def __init__(self, optimizer, grad_scaler, clip_grad):
        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)

        self.grad_scaler = grad_scaler
        self.clip_grad = clip_grad

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # ======================
        # master parameter stuff
        # ======================

        # Three groups of parameters:
        #   fp16_groups: original fp16 parameters
        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.fp16_groups = []
        self.fp32_from_fp16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # fp16 params:
                    if param.type() == 'torch.cuda.HalfTensor':
                        fp16_params_this_group.append(param)
                        # Create an fp32 copy.
                        master_param = param.detach().clone().float()
                        # Store grads on the copy.
                        master_param.requires_grad = True
                        # Copy tensor model parallel attributes.
                        master_param.tensor_model_parallel = \
                            param.tensor_model_parallel
                        #mpu.copy_tensor_model_parallel_attributes(master_param,
                        #                                          param)
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                        # Reset existing state dict key to the new master param.
                        if param in self.optimizer.state:
                            self.optimizer.state[master_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or "
                                        "torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors.
        self.optimizer.load_state_dict(self.optimizer.state_dict())
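
In short, every trainable fp16 parameter gets an fp32 "master" copy that replaces it inside the wrapped optimizer's param_groups, so the actual update runs in fp32 while the model weights stay fp16. A stripped-down, CPU-only sketch of that swap and the copy-back after a step (hypothetical names, plain SGD instead of Adam; the real constructor also carries over tensor-model-parallel attributes and any preexisting optimizer state):

    # --- illustration only, not part of this commit ---
    import torch

    model_param = torch.zeros(3, dtype=torch.half, requires_grad=True)  # fp16 "model" weight
    master_param = model_param.detach().clone().float()                 # fp32 master copy
    master_param.requires_grad = True

    # The base optimizer is pointed at the master copy, so its update runs in fp32.
    base_opt = torch.optim.SGD([master_param], lr=0.1)
    master_param.grad = torch.ones_like(master_param)   # stand-in for the copied/unscaled grads
    base_opt.step()

    # After the step, the fp16 model weight is refreshed from the master copy
    # (the real step() below does this with apex's multi_tensor_scale and factor 1.0).
    with torch.no_grad():
        model_param.copy_(master_param)
    print(model_param)   # -0.1 everywhere, still a float16 tensor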
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model-related parameters, i.e.,
        fp16_groups & fp32_from_fp32_groups."""
        for group in self.fp16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)
    def get_loss_scale(self):
        return self.grad_scaler.scale

    @torch.no_grad()
    def step(self):

        # ==================================================
        # Copy gradients from model params to master params.
        # ==================================================
        # This only needs to be done for the fp16 group.
        model_grads = []
        master_grads = []
        for model_group, master_group in zip(self.fp16_groups,
                                             self.fp32_from_fp16_groups):
            for model_param, master_param in zip(model_group, master_group):
                if model_param.grad is not None:
                    if master_param.grad is None:
                        master_param.grad = torch.empty_like(master_param)
                    model_grads.append(model_param.grad)
                    master_grads.append(master_param.grad)
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [model_grads, master_grads],
                             1.0)

        # ==============================
        # Unscale and check for inf/nan.
        # ==============================
        # Append fp32 parameters.
        for master_group in self.fp32_from_fp32_groups:
            for master_param in master_group:
                if master_param.grad is not None:
                    master_grads.append(master_param.grad)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan.
        torch._amp_foreach_non_finite_check_and_unscale_(
            master_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())

        # ==================================
        # We are done with scaling gradients
        # so we can update the loss scale.
        # ==================================
        found_inf_flag = (self.found_inf.item() > 0)
        self.grad_scaler.update(found_inf_flag)

        # =====================================
        # If we found inf/nan, skip the update.
        # =====================================
        if found_inf_flag:
            return False

        # ==========================
        # Clip the master gradients.
        # ==========================
        fp32_params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                fp32_params.append(param)
        mpu.clip_grad_norm(fp32_params, self.clip_grad)

        # ===================
        # Step the optimizer.
        # ===================
        self.optimizer.step()

        # =================================
        # Update params from master params.
        # =================================
        # Only needed for the fp16 params.
        model_data = []
        master_data = []
        for model_group, master_group in zip(self.fp16_groups,
                                             self.fp32_from_fp16_groups):
            for model_param, master_param in zip(model_group, master_group):
                model_data.append(model_param.data)
                master_data.append(master_param.data)
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [master_data, model_data],
                             1.0)

        return True
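
step() therefore returns False exactly when the iteration should be skipped; the train_step() changes in megatron/training.py below consume that flag roughly like this (a condensed paraphrase of the diff, not a verbatim excerpt):

    # --- condensed sketch of the train_step() usage shown below ---
    optimizer.zero_grad()
    # ... forward/backward passes, with the fp16 path scaling the loss via
    #     output_tensor = optimizer.scale_loss(output_tensor) before backward ...
    update_successfull = optimizer.step()   # False => inf/nan found, parameters untouched
    skipped_iter = 0
    if update_successfull:
        lr_scheduler.step(increment=increment)
    else:
        skipped_iter = 1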
megatron/training.py
@@ -39,7 +39,9 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
-from megatron.fp16 import FP16_Optimizer
+#from megatron.fp16 import FP16_Optimizer
+from megatron.optimizer.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
@@ -232,6 +234,8 @@ def get_optimizer(model):
     # Wrap into fp16 optimizer.
     if args.fp16:
+        optimizer = get_megatron_optimizer(optimizer)
+        '''
         optimizer = FP16_Optimizer(optimizer,
                                    static_loss_scale=args.loss_scale,
                                    dynamic_loss_scale=args.dynamic_loss_scale,
@@ -239,7 +243,7 @@ def get_optimizer(model):
                                        'scale_window': args.loss_scale_window,
                                        'min_scale': args.min_scale,
                                        'delayed_shift': args.hysteresis})
+        '''
     return optimizer
@@ -367,12 +371,16 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
         input_tensor.retain_grad()
     # Backward pass.
+    if args.fp16:
+        if output_tensor_grad is None:
+            output_tensor = optimizer.scale_loss(output_tensor)
+    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    '''
     if args.fp16 and output_tensor_grad is None:
         optimizer.backward(output_tensor, update_master_grads=False,
                            output_tensor_grad=output_tensor_grad)
     else:
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    '''
     # Collect the grad of the input_tensor.
     input_tensor_grad = None
     if input_tensor is not None:
@@ -590,10 +598,13 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()
     # Set grad to zero.
+    optimizer.zero_grad()
+    '''
     if args.fp16:
         optimizer.zero_grad(set_grads_to_None=True)
     else:
         optimizer.zero_grad()
+    '''
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         losses_reduced = forward_backward_pipelining(
@@ -627,12 +638,14 @@ def train_step(forward_step_func, data_iterator,
         timers('backward-embedding-all-reduce').stop()
     # Update master gradients.
+    '''
     timers('backward-master-grad').start()
     if args.fp16:
         optimizer.update_master_grads()
     timers('backward-master-grad').stop()
+    '''
     # Clipping gradients helps prevent the exploding gradient.
+    '''
     timers('backward-clip-grad').start()
     if args.clip_grad > 0.:
         if not args.fp16:
@@ -647,19 +660,20 @@ def train_step(forward_step_func, data_iterator,
         else:
             optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()
+    '''
     # Update parameters.
     timers('optimizer').start()
-    optimizer.step()
+    update_successfull = optimizer.step()
     timers('optimizer').stop()
     # Update learning rate.
+    skipped_iter = 0
-    if not (args.fp16 and optimizer.overflow):
+    if update_successfull:
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size
         lr_scheduler.step(increment=increment)
-        skipped_iter = 0
     else:
         skipped_iter = 1
@@ -845,7 +859,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Logging.
         loss_scale = None
         if args.fp16:
-            loss_scale = optimizer.loss_scale
+            loss_scale = optimizer.get_loss_scale().item()
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
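
One easy-to-miss detail in the last hunk: get_loss_scale() returns the grad scaler's one-element CUDA tensor (self.grad_scaler.scale), not a Python number, which is why the logging path now calls .item() on it before handing the value to training_log():

    loss_scale = optimizer.get_loss_scale().item()   # e.g. tensor([65536.], device='cuda:0') -> 65536.0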