OpenDAS / Megatron-LM

Commit 1a2cb60c, authored Mar 08, 2021 by Jared Casper

Merge branch 'bfloat_with_fp32_grad_acc' into 'main'

Bfloat with fp32 grad acc

See merge request ADLR/megatron-lm!247

Parents: 87b8b9dc, b4bc51b1

Showing 10 changed files with 504 additions and 211 deletions.
megatron/arguments.py            +25   -4
megatron/model/__init__.py        +5   -3
megatron/model/bert_model.py      +1   -1
megatron/model/distributed.py   +178  -72
megatron/model/module.py         +26  -12
megatron/model/transformer.py    +15   -2
megatron/optimizer/__init__.py   +35  -14
megatron/optimizer/optimizer.py +174  -83
megatron/training.py             +33  -15
megatron/utils.py                +12   -5
megatron/arguments.py

@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
+        assert not args.bf16
         args.params_dtype = torch.half
+    if args.bf16:
+        assert not args.fp16
+        args.params_dtype = torch.bfloat16
+        # No fusion is supported for bfloat16 for now.
+        assert not args.masked_softmax_fusion
+        assert not args.bias_gelu_fusion
+        assert not args.bias_dropout_fusion
     if args.rank == 0:
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

+    # If we do accumulation and all-reduces in fp32, we need to have
+    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    if args.accumulate_allreduce_grads_in_fp32:
+        assert args.DDP_impl == 'local'
+        args.use_contiguous_buffers_in_ddp = True
+
     if args.dataloader_type is None:
         args.dataloader_type = 'single'

@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
     if args.fp32_residual_connection:
-        assert args.fp16, \
-            'residual connection in fp32 only supported when using fp16.'
+        assert args.fp16 or args.bf16, \
+            'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \

@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run model in bfloat16 mode.')
     group.add_argument('--loss-scale', type=float, default=None,
                        help='Static loss scaling, positive power of 2 '
                        'values can improve fp16 convergence. If None, dynamic'

@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32. '
                        'This flag is ignored unless '
                        '--no-query-key-layer-scaling is specified.')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='All-reduce in fp32')
+    group.add_argument('--accumulate-allreduce-grads-in-fp32',
+                       action='store_true',
+                       help='Gradient accumulation and all-reduce in fp32.')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')

@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
+                       help='If set, use contiguous buffer in DDP. Note that '
+                       'this option only works with local DDP.')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                        action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
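The precision flags added here interact: --fp16 and --bf16 are mutually exclusive, each picking a params dtype, and --accumulate-allreduce-grads-in-fp32 forces local DDP with contiguous buffers. A minimal standalone sketch of that validation logic, not part of the commit (the helper name resolve_precision_args is made up here), can be useful for sanity-checking a launch configuration:

import argparse
import torch

def resolve_precision_args(argv=None):
    # Mirror of the validation rules added above, for quick experimentation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--bf16', action='store_true')
    parser.add_argument('--accumulate-allreduce-grads-in-fp32', action='store_true')
    parser.add_argument('--DDP-impl', default='local', choices=['local', 'torch'])
    parser.add_argument('--use-contiguous-buffers-in-ddp', action='store_true')
    args = parser.parse_args(argv)

    # fp16 and bf16 are mutually exclusive; each selects a params dtype.
    assert not (args.fp16 and args.bf16)
    args.params_dtype = torch.float
    if args.fp16:
        args.params_dtype = torch.half
    if args.bf16:
        args.params_dtype = torch.bfloat16

    # fp32 grad accumulation/all-reduce requires local DDP with a
    # contiguous grad buffer, exactly as parse_args() enforces.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True
    return args

# Example: a typical bf16 launch configuration.
cfg = resolve_precision_args(['--bf16', '--accumulate-allreduce-grads-in-fp32'])
print(cfg.params_dtype, cfg.use_contiguous_buffers_in_ddp)  # torch.bfloat16 True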
megatron/model/__init__.py

@@ -16,11 +16,13 @@
 _LAYER_NORM = None

-def import_layernorm(fp32_residual_connection):
+def import_layernorm(fp32_residual_connection, bf16):
     global _LAYER_NORM
     if not _LAYER_NORM:
-        if fp32_residual_connection:
+        if bf16:
+            from torch.nn import LayerNorm
+        elif fp32_residual_connection:
             from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
         else:
             from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
                         GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
-from .module import FP16Module
+from .module import Float16Module
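The new bf16 argument changes which LayerNorm implementation import_layernorm hands back. A tiny sketch of the same dispatch order, with the implementations reduced to labels so it runs without apex installed (select_layernorm_impl is an illustrative stand-in, not the function above):

def select_layernorm_impl(fp32_residual_connection, bf16):
    # Same precedence as the updated import_layernorm(): bf16 wins, then the
    # fp32-residual path, then the default apex fused kernel.
    if bf16:
        return 'torch.nn.LayerNorm'              # no fused kernel path for bf16 yet
    if fp32_residual_connection:
        return 'megatron MixedFusedLayerNorm'    # handles the fp32 residual stream
    return 'apex FusedLayerNorm'

for flags in [(False, False), (True, False), (False, True), (True, True)]:
    print(flags, '->', select_layernorm_impl(*flags))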
megatron/model/bert_model.py

@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
megatron/model/distributed.py

@@ -13,100 +13,206 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from abc import ABC
+from abc import abstractmethod
+
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import torch.distributed as dist
-from torch.nn.modules import Module
-from torch.autograd import Variable

-from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule


-class DistributedDataParallel(MegatronModule):
-
-    def __init__(self, module):
-        super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO \
-            else False
-
-        self.module = module
-        self.data_parallel_group = mpu.get_data_parallel_group()
-
-        def allreduce_params(reduce_after=True, no_scale=False,
-                             fp32_allreduce=False):
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                buckets = {}
-                for name, param in self.module.named_parameters():
-                    if param.requires_grad and param.grad is not None:
-                        tp = (param.data.type())
-                        if tp not in buckets:
-                            buckets[tp] = []
-                        buckets[tp].append(param)
-                if self.warn_on_half:
-                    if torch.cuda.HalfTensor in buckets:
-                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
-                              " It is recommended to use the NCCL backend in this case.")
-                        self.warn_on_half = False
-                for tp in buckets:
-                    bucket = buckets[tp]
-                    grads = [param.grad.data for param in bucket]
-                    coalesced = _flatten_dense_tensors(grads)
-                    if fp32_allreduce:
-                        coalesced = coalesced.float()
-                    if not no_scale and not reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    dist.all_reduce(coalesced, group=self.data_parallel_group)
-                    torch.cuda.synchronize()
-                    if not no_scale and reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                        buf.copy_(synced)
-        self.hook_handles = []
-        self.hooks = []
-        for param in list(self.module.parameters()):
-            def allreduce_hook(*unused):
-                Variable._execution_engine.queue_callback(allreduce_params)
-            # handle = param.register_hook(allreduce_hook)
-            # self.hooks.append(allreduce_hook)
-            # self.hook_handles.append(handle)
-        self.allreduce_params = allreduce_params
+class MemoryBuffer:
+
+    def __init__(self, numel, dtype):
+        self.numel = numel
+        self.dtype = dtype
+        self.data = torch.zeros(self.numel,
+                                dtype=self.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+
+    def zero(self):
+        """Reset the buffer to zero."""
+        self.data.zero_()
+
+    def get(self, shape, start_index):
+        """Return a tensor with the input `shape` as a view into the
+        1-D data starting at `start_index`."""
+        end_index = start_index + shape.numel()
+        assert end_index <= self.numel, \
+            'requested tensor is out of the buffer range.'
+        buffer_tensor = self.data[start_index:end_index]
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+
+class DistributedDataParallelBase(MegatronModule, ABC):
+    """Abstract class for DDP."""
+
+    def __init__(self, module):
+        super(DistributedDataParallelBase, self).__init__()
+        # Keep a pointer to the model.
+        self.module = module
+
+    @abstractmethod
+    def allreduce_gradients(self):
+        pass

     def forward(self, *inputs, **kwargs):
-        self.needs_reduction = True
         return self.module(*inputs, **kwargs)

     def state_dict(self, destination=None, prefix='', keep_vars=False):
-        #[h.remove() for h in self.hook_handles]
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-        # for handle, hook in zip(self.hook_handles, self.hooks):
-        #     d = handle.hooks_dict_ref()
-        #     d[handle.id] = hook
-        return sd
+        return self.module.state_dict(destination, prefix, keep_vars)

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                           keep_vars)

     def load_state_dict(self, state_dict, strict=True):
         self.module.load_state_dict(state_dict, strict=strict)

-    '''
-    def _sync_buffers(self):
-        buffers = list(self.module._all_buffers())
-        if len(buffers) > 0:
-            # cross-node buffer sync
-            flat_buffers = _flatten_dense_tensors(buffers)
-            dist.broadcast(flat_buffers, 0)
-            for buf, synced in zip(buffers,
-                                   _unflatten_dense_tensors(flat_buffers, buffers)):
-                buf.copy_(synced)
-
-    def train(self, mode=True):
-        # Clear NCCL communicator and CUDA event cache of the default group ID,
-        # These cache will be recreated at the later call. This is currently a
-        # work-around for a potential NCCL deadlock.
-        if dist._backend == dist.dist_backend.NCCL:
-            dist._clear_group_cache()
-        super(DistributedDataParallel, self).train(mode)
-        self.module.train(mode)
-    '''
+
+class DistributedDataParallel(DistributedDataParallelBase):
+    """DDP with contiguous buffers options to store and accumulate gradients.
+    This class:
+        - has the potential to reduce memory fragmentation.
+        - provides the option to do the gradient accumulation
+          in a type other than the params type (for example fp32).
+
+    Arguments:
+        module: input model.
+        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
+            and the gradient all-reduce all in float32. If this option is
+            true, we require `use_contiguous_buffers` to be true too.
+        use_contiguous_buffers: if true, use a contiguous buffer to store the
+            gradients.
+    """
+
+    def __init__(self, module,
+                 accumulate_allreduce_grads_in_fp32,
+                 use_contiguous_buffers):
+
+        super(DistributedDataParallel, self).__init__(module)
+
+        self.accumulate_allreduce_grads_in_fp32 \
+            = accumulate_allreduce_grads_in_fp32
+        self.use_contiguous_buffers = use_contiguous_buffers
+        # If we are using fp32-accumulate-allreduce explicitly
+        # this means we need main grads in a continuous buffer.
+        if self.accumulate_allreduce_grads_in_fp32:
+            assert self.use_contiguous_buffers
+
+        # ===================================
+        # Rest of this part applies only to
+        # the case we use continuous buffers.
+        # ===================================
+        self._grad_buffers = None
+        if self.use_contiguous_buffers:
+            self._grad_buffers = {}
+
+            # Simple function to define buffer type.
+            def _get_buffer_type(param):
+                return torch.float if \
+                    self.accumulate_allreduce_grads_in_fp32 else param.dtype
+
+            # First calculate total number of elements per type.
+            type_num_elements = {}
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+                                               + param.data.nelement()
+
+            # Allocate the buffer.
+            for dtype, num_elements in type_num_elements.items():
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+            # Assume the back prop order is reverse the params order,
+            # store the start index for the gradients.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] -= param.data.nelement()
+                    param.main_grad = self._grad_buffers[dtype].get(
+                        param.data.shape, type_num_elements[dtype])
+
+            # Backward hook.
+            # Accumulation function for the gradients. We need
+            # to store them so they don't go out of scope.
+            self.grad_accs = []
+            # Loop over all the parameters in the model.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    # Expand so we get access to grad_fn.
+                    param_tmp = param.expand_as(param)
+                    # Get the gradient accumulator function.
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    grad_acc.register_hook(self._make_param_hook(param))
+                    self.grad_accs.append(grad_acc)
+
+    def _make_param_hook(self, param):
+        """Create the all-reduce hook for backprop."""
+        # Hook used for back-prop.
+        def param_hook(*unused):
+            # Add the gradient to the buffer.
+            if param.grad.data is not None:
+                param.main_grad.add_(param.grad.data)
+                # Now we can deallocate grad memory.
+                param.grad = None
+        return param_hook
+
+    def zero_grad_buffer(self):
+        """Set the grad buffer data to zero. Needs to be called at the
+        beginning of each iteration."""
+        assert self._grad_buffers is not None, 'buffers are not initialized.'
+        for _, buffer_ in self._grad_buffers.items():
+            buffer_.zero()
+
+    def allreduce_gradients(self):
+        """Reduce gradients across data parallel ranks."""
+        # If we have buffers, simply reduce the data in the buffer.
+        if self._grad_buffers is not None:
+            for _, buffer_ in self._grad_buffers.items():
+                buffer_.data /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    buffer_.data, group=mpu.get_data_parallel_group())
+        else:
+            # Otherwise, bucketize and all-reduce.
+            buckets = {}
+            # Pack the buckets.
+            for param in self.module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = param.data.type()
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+                    param.main_grad = param.grad
+
+            # For each bucket, all-reduce and copy all-reduced grads.
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                coalesced /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    coalesced, group=mpu.get_data_parallel_group())
+                for buf, synced in zip(grads, _unflatten_dense_tensors(
+                        coalesced, grads)):
+                    buf.copy_(synced)
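The heart of the new local DDP is the MemoryBuffer: one flat tensor per dtype, sliced into main_grad views so gradients can be accumulated (optionally in fp32) and all-reduced as a single contiguous chunk. A stripped-down, CPU-only sketch of that idea, not the class above (FlatGradBuffer is a toy stand-in, and the backward hook is simulated with an explicit loop):

import torch

class FlatGradBuffer:
    """Toy version of MemoryBuffer: one flat tensor, sliced into views."""
    def __init__(self, numel, dtype):
        self.data = torch.zeros(numel, dtype=dtype)

    def get(self, shape, start_index):
        end_index = start_index + shape.numel()
        assert end_index <= self.data.numel()
        return self.data[start_index:end_index].view(shape)

# Give every parameter of a small model a `main_grad` view into one fp32
# buffer, so accumulation happens in fp32 even if the params were bf16.
model = torch.nn.Linear(4, 3)
numel = sum(p.numel() for p in model.parameters())
buf = FlatGradBuffer(numel, torch.float)

offset = numel
for p in model.parameters():          # offsets assigned in reverse, as in the real class
    offset -= p.numel()
    p.main_grad = buf.get(p.shape, offset)

model(torch.randn(2, 4)).sum().backward()
for p in model.parameters():
    p.main_grad.add_(p.grad.float())  # what the registered backward hook does
    p.grad = None                     # grad memory can be released immediately

print(buf.data.shape)  # all accumulated gradients live in one flat tensor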
megatron/model/module.py

@@ -25,6 +25,7 @@ from megatron import mpu
 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)

@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
                 "this needs to be handled manually. If you are training "
                 "something is definitely wrong.")


 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val`
     is a nested tuple/list structure."""

@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
     return rtn


-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
+def fp32_to_float16(val, float16_convertor):
+    """Convert fp32 `val` to fp16/bf16"""
     def half_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
         if isinstance(val_typecheck, _FLOAT_TYPES):
-            val = val.half()
+            val = float16_convertor(val)
         return val
     return conversion_helper(val, half_conversion)


-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
+def float16_to_fp32(val):
+    """Convert fp16/bf16 `val` to fp32"""
     def float_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
-        if isinstance(val_typecheck, _HALF_TYPES):
+        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
             val = val.float()
         return val
     return conversion_helper(val, float_conversion)


-class FP16Module(MegatronModule):
+class Float16Module(MegatronModule):

-    def __init__(self, module):
-        super(FP16Module, self).__init__()
-        self.add_module('module', module.half())
+    def __init__(self, module, args):
+        super(Float16Module, self).__init__()
+
+        if args.fp16:
+            self.add_module('module', module.half())
+            def float16_convertor(val):
+                return val.half()
+        elif args.bf16:
+            self.add_module('module', module.bfloat16())
+            def float16_convertor(val):
+                return val.bfloat16()
+        else:
+            raise Exception('should not be here')
+
+        self.float16_convertor = float16_convertor

     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
-            inputs = fp32_to_fp16(inputs)
+            inputs = fp32_to_float16(inputs, self.float16_convertor)
         outputs = self.module(*inputs, **kwargs)
         if mpu.is_pipeline_last_stage():
-            outputs = fp16_to_fp32(outputs)
+            outputs = float16_to_fp32(outputs)
         return outputs
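Float16Module works by recursively walking nested inputs/outputs and applying a convertor closure that is either .half() or .bfloat16(). A small self-contained sketch of that pattern (convert_nested and make_float16_convertor are illustrative names, not the functions above):

import torch

def convert_nested(val, conversion):
    # Same idea as conversion_helper(): walk nested tuples/lists and
    # apply `conversion` to every tensor leaf.
    if isinstance(val, (tuple, list)):
        return type(val)(convert_nested(v, conversion) for v in val)
    return conversion(val) if isinstance(val, torch.Tensor) else val

def make_float16_convertor(bf16):
    # Mirrors the closure defined in Float16Module.__init__.
    return (lambda t: t.bfloat16()) if bf16 else (lambda t: t.half())

convertor = make_float16_convertor(bf16=True)
inputs = (torch.randn(2, 3), [torch.randn(4)], 'not a tensor')
converted = convert_nested(
    inputs, lambda t: convertor(t) if t.is_floating_point() else t)
print([x.dtype if isinstance(x, torch.Tensor) else x
       for x in (converted[0], converted[1][0], converted[2])])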
megatron/model/transformer.py

@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

+        self.bf16 = args.bf16
+        self.fp32_residual_connection = args.fp32_residual_connection
+
         # Layernorm on the input data.
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
+        if self.bf16 and self.fp32_residual_connection:
+            layernorm_output = layernorm_output.bfloat16()

         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(layernorm_output,

@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
+        if self.bf16 and self.fp32_residual_connection:
+            layernorm_output = layernorm_output.bfloat16()

         if self.layer_type == LayerType.decoder:
             attention_output, attention_bias = \

@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):
             # Layer norm post the decoder attention
             layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+            if self.bf16 and self.fp32_residual_connection:
+                layernorm_output = layernorm_output.bfloat16()

         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)

@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()

+        self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection

         # Store activation checkpointing flag.

@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(args.fp32_residual_connection)
+            LayerNorm = import_layernorm(self.fp32_residual_connection,
+                                         self.bf16)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)

@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
+            if self.bf16 and self.fp32_residual_connection:
+                output = output.bfloat16()
         else:
             output = hidden_states
         if get_key_value:
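With --bf16 and --fp32-residual-connection set together, the residual stream and the (unfused) layer norm stay in fp32 and only the layer-norm output is cast back to bf16 before the attention/MLP blocks. A simplified illustration of that dtype flow, using plain torch.nn modules rather than the actual Megatron layers, with the residual add assumed to stay in fp32:

import torch

# Sketch of the precision pattern added to ParallelTransformerLayer.
hidden = torch.randn(8, 2, 16, dtype=torch.float32)  # fp32 residual stream
layernorm = torch.nn.LayerNorm(16)                   # fp32 params (torch.nn.LayerNorm)
attention = torch.nn.Linear(16, 16).bfloat16()       # stand-in for a bf16 compute block

ln_out = layernorm(hidden)            # fp32 in, fp32 out
ln_out = ln_out.bfloat16()            # the cast added by this commit
attn_out = attention(ln_out)          # bf16 matmul
hidden = hidden + attn_out.float()    # residual accumulated in fp32
print(ln_out.dtype, hidden.dtype)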
megatron/optimizer/__init__.py

@@ -20,7 +20,7 @@ from megatron import get_args
 from megatron.model import import_layernorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
+from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


 def _get_params_for_weight_decay_optimization(modules):

@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
     Layernorms and biases will have no weight decay but the rest will.
     """
     args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)
+    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)

     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}

@@ -67,24 +67,45 @@ def get_megatron_optimizer(model):
                           momentum=args.sgd_momentum)
     else:
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))

-    if args.fp16:
+    # Determine whether the params have main-grad field.
+    params_have_main_grad = False
+    if args.DDP_impl == 'local':
+        params_have_main_grad = True
+
+    if args.fp16 or args.bf16:
+
+        # Grad scaler:
+        #    if loss-scale is provided, instantiate the constant scaler.
+        #    if we are using fp16 and loss-scale is not present, use a
+        #       dynamic scaler.
+        #    otherwise we are running in bf16 with no loss-scale so
+        #       leave it as None.
+        grad_scaler = None
         # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)
         # Dynamic loss scale.
         else:
-            grad_scaler = DynamicGradScaler(
-                initial_scale=args.initial_loss_scale,
-                min_scale=args.min_loss_scale,
-                growth_factor=2.0,
-                backoff_factor=0.5,
-                growth_interval=args.loss_scale_window,
-                hysteresis=args.hysteresis)
+            if args.fp16:
+                grad_scaler = DynamicGradScaler(
+                    initial_scale=args.initial_loss_scale,
+                    min_scale=args.min_loss_scale,
+                    growth_factor=2.0,
+                    backoff_factor=0.5,
+                    growth_interval=args.loss_scale_window,
+                    hysteresis=args.hysteresis)
+
         # Megatron optimizer.
-        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
-                                           args.clip_grad,
-                                           args.log_num_zeros_in_grad)
+        return Float16OptimizerWithFloat16Params(optimizer,
+                                                 args.clip_grad,
+                                                 args.log_num_zeros_in_grad,
+                                                 params_have_main_grad,
+                                                 args.bf16,
+                                                 grad_scaler)

     # FP32.
-    return FP32Optimizer(optimizer, args.clip_grad,
-                         args.log_num_zeros_in_grad)
+    return FP32Optimizer(optimizer, args.clip_grad,
+                         args.log_num_zeros_in_grad,
+                         params_have_main_grad)
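The grad-scaler selection added above has three outcomes. A compact sketch of just that decision, returning descriptive strings instead of the real scaler objects (choose_grad_scaler is an illustrative helper, not part of the commit):

def choose_grad_scaler(fp16, bf16, loss_scale):
    # Reproduces the decision added to get_megatron_optimizer():
    #   explicit --loss-scale   -> constant scaler (fp16 or bf16)
    #   fp16 without loss-scale -> dynamic scaler
    #   bf16 without loss-scale -> no scaler at all (None)
    assert fp16 or bf16
    if loss_scale:
        return f'ConstantGradScaler({loss_scale})'
    if fp16:
        return 'DynamicGradScaler(...)'
    return None

print(choose_grad_scaler(fp16=True,  bf16=False, loss_scale=None))    # dynamic
print(choose_grad_scaler(fp16=False, bf16=True,  loss_scale=None))    # None
print(choose_grad_scaler(fp16=False, bf16=True,  loss_scale=1024.0))  # constant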
megatron/optimizer/optimizer.py

@@ -46,24 +46,37 @@ def _zero_grad_group_helper(group, set_to_none):
 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another."""
+    """Use multi-tensor-applier to copy values from one list to another.
+    We don't have a bfloat16 implementation so for now if the overflow_buf
+    is not provided, we default back to simple loop copy to be compatible
+    with bfloat16."""
     if overflow_buf:
         overflow_buf.fill_(0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             overflow_buf,
+                             [this, that],
+                             1.0)
     else:
-        overflow_buf = torch.cuda.IntTensor([0])
-    # Scaling with factor `1.0` is equivalent to copy.
-    multi_tensor_applier(amp_C.multi_tensor_scale,
-                         overflow_buf,
-                         [this, that],
-                         1.0)
+        for this_, that_ in zip(this, that):
+            that_.copy_(this_)


 class MegatronOptimizer(ABC):

-    def __init__(self, optimizer):
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'
+        # Set gradient clipping and logging params.
+        self.clip_grad = clip_grad
+        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+        self.params_have_main_grad = params_have_main_grad

     def get_parameters(self):
         params = []

@@ -72,31 +85,38 @@ class MegatronOptimizer(ABC):
                 params.append(param)
         return params

     def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
         return clip_grad_norm_fp32(params, clip_grad)

     def count_zeros(self):
         params = self.get_parameters()
         return count_zeros_fp32(params)

     @abstractmethod
     def zero_grad(self, set_to_none=True):
         pass

     @abstractmethod
     def get_loss_scale(self):
         """The output should be a cuda tensor of size 1."""
         pass

     def scale_loss(self, loss):
         """Simple scaling."""
         return self.get_loss_scale() * loss

     @abstractmethod
     def step(self):
         pass

     @abstractmethod
     def reload_model_params(self):
         """Refreshes any internal state from the current model parameters.

@@ -106,14 +126,17 @@ class MegatronOptimizer(ABC):
        with main parameters, the main parameters need to also be updated."""
         pass

     @abstractmethod
     def state_dict(self):
         pass

     @abstractmethod
     def load_state_dict(self, state_dict):
         pass

     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):

@@ -124,6 +147,7 @@ class MegatronOptimizer(ABC):
     state = property(_get_state, _set_state)

     # Promote param_groups so it can be retrieved or set via
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)

@@ -137,50 +161,90 @@ class MegatronOptimizer(ABC):
-class FP16OptimizerWithFP16Params(MegatronOptimizer):
-
-    def __init__(self, optimizer, grad_scaler, clip_grad, log_num_zeros_in_grad):
-        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
+class Float16OptimizerWithFloat16Params(MegatronOptimizer):
+    """Float16 optimizer for fp16 and bf16 data types.
+
+    Arguments:
+        optimizer: base optimizer such as Adam or SGD
+        clip_grad: clip gradients with this global L2 norm. Note
+            that clipping is ignored if clip_grad == 0.
+        log_num_zeros_in_grad: return number of zeros in the gradients.
+        params_have_main_grad: flag indicating if parameters have
+            a `main_grad` field. If this is set, we are assuming
+            that the model parameters are stored in the `main_grad`
+            field instead of the typical `grad` field. This happens
+            for the DDP cases where there is a continuous buffer
+            holding the gradients. For example for bfloat16, we want
+            to do gradient accumulation and all-reduces in float32
+            and as a result we store those gradients in the main_grad.
+            Note that main grad is not necessarily in float32.
+        bf16: if true, the model is running in bfloat16.
+        grad_scaler: used for scaling gradients. Note that this can be
+            None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constant gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+    """
+
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, bf16, grad_scaler):
+        super(Float16OptimizerWithFloat16Params, self).__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad)

+        self.bf16 = bf16
         self.grad_scaler = grad_scaler
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+        # None grad scaler is only supported for bf16.
+        if self.grad_scaler is None:
+            assert self.bf16, 'fp16 expects a grad scaler.'

         # Tensor used to determine if a nan/inf has happened.
         # Any non-zero value indicates inf/nan.
-        self.found_inf = torch.cuda.FloatTensor([0.0])
+        # Note that we keep this for the cases that grad scaler is none.
+        # We still record nan/inf if we have a bfloat16 with a grad scaler.
+        if self.grad_scaler:
+            self.found_inf = torch.cuda.FloatTensor([0.0])

         # Dummy tensor needed for apex multi-apply tensor.
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # For bfloat, we don't have multi-tensor apply and for now
+        # we set it to none so the multi-tensor apply gets ignored.
+        if bf16:
+            self._dummy_overflow_buf = None
+        else:
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+        # In case grad scaler is not passed, define the unity scale.
+        if self.grad_scaler is None:
+            self._scale_one = torch.cuda.FloatTensor([1.0])

         # ======================
         # main parameter stuff
         # ======================

         # Three groups of parameters:
-        #   fp16_groups: original fp16 parameters
-        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
         #   fp32_from_fp32_groups: original fp32 parameters
-        self.fp16_groups = []
-        self.fp32_from_fp16_groups = []
+        self.float16_groups = []
+        self.fp32_from_float16_groups = []
         self.fp32_from_fp32_groups = []

         # For all the groups in the original optimizer:
         for param_group in self.optimizer.param_groups:
-            fp16_params_this_group = []
+            float16_params_this_group = []
             fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
+            fp32_from_float16_params_this_group = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
-                    # fp16 params:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        fp16_params_this_group.append(param)
+                    # float16 params:
+                    if param.type() in ['torch.cuda.HalfTensor',
+                                        'torch.cuda.BFloat16Tensor']:
+                        float16_params_this_group.append(param)
                         # Create a copy
                         main_param = param.detach().clone().float()
                         # Store grads
                         main_param.requires_grad = True
                         # Copy tensor model parallel attributes.
                         mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                   param)

@@ -188,7 +252,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
-                        fp32_from_fp16_params_this_group.append(main_param)
+                        fp32_from_float16_params_this_group.append(main_param)
                         # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \

@@ -200,13 +264,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                         param_group['params'][i] = param

                     else:
-                        raise TypeError("Wrapped parameters must be either "
-                                        "torch.cuda.FloatTensor or "
-                                        "torch.cuda.HalfTensor. "
-                                        "Received {}".format(param.type()))
+                        raise TypeError('Wrapped parameters must be one of '
+                                        'torch.cuda.FloatTensor, '
+                                        'torch.cuda.HalfTensor, or '
+                                        'torch.cuda.BFloat16Tensor. '
+                                        'Received {}'.format(param.type()))

-            self.fp16_groups.append(fp16_params_this_group)
-            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+            self.float16_groups.append(float16_params_this_group)
+            self.fp32_from_float16_groups.append(
+                fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

         # Leverage state_dict() and load_state_dict() to

@@ -216,37 +282,40 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
-        fp16_groups & fp32_from_fp32_groups."""
-        for group in self.fp16_groups:
+        float16_groups & fp32_from_fp32_groups."""
+        for group in self.float16_groups:
             _zero_grad_group_helper(group, set_to_none)
         for group in self.fp32_from_fp32_groups:
             _zero_grad_group_helper(group, set_to_none)

     def get_loss_scale(self):
+        if self.grad_scaler is None:
+            return self._scale_one
         return self.grad_scaler.scale

     def _copy_model_grads_to_main_grads(self):
-        # This only needs to be done for the fp16 group.
-        model_grads = []
-        main_grads = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
-                if model_param.grad is not None:
-                    if main_param.grad is None:
-                        main_param.grad = torch.empty_like(main_param)
-                    model_grads.append(model_param.grad.data)
-                    main_grads.append(main_param.grad.data)
-        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
-                                        overflow_buf=self._dummy_overflow_buf)
+                if self.params_have_main_grad:
+                    main_param.grad = model_param.main_grad.float()
+                else:
+                    if model_param.grad is not None:
+                        main_param.grad = model_param.grad.float()
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        if self.params_have_main_grad:
+            for model_group in self.fp32_from_fp32_groups:
+                for model_param in model_group:
+                    model_param.grad = model_param.main_grad

     def _unscale_main_grads_and_check_for_nan(self):
         main_grads = []
-        # fp32 params from fp16 ones.
-        for main_group in self.fp32_from_fp16_groups:
+        # fp32 params from float16 ones.
+        for main_group in self.fp32_from_float16_groups:
             for main_param in main_group:
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)

@@ -270,11 +339,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag

-    def _get_model_and_main_params_data_fp16(self):
+    def _get_model_and_main_params_data_float16(self):
         model_data = []
         main_data = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
                 main_data.append(main_param.data)

@@ -282,15 +351,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def _copy_main_params_to_model_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)

     def _copy_model_params_to_main_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                         overflow_buf=self._dummy_overflow_buf)

@@ -298,6 +367,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def reload_model_params(self):
         self._copy_model_params_to_main_params()

     @torch.no_grad()
     def step(self):

@@ -308,18 +378,22 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()

-        # Unscale and check for inf/nan.
-        timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-        timers('optimizer-unscale-and-check-inf').stop()
-
-        # We are done with scaling gradients
-        # so we can update the loss scale.
-        self.grad_scaler.update(found_inf_flag)
-
-        # If we found inf/nan, skip the update.
-        if found_inf_flag:
-            return False, None, None
+        # Do unscale, check for inf, and update grad scaler only for
+        # the case that grad scaler is provided.
+        if self.grad_scaler:
+
+            # Unscale and check for inf/nan.
+            timers('optimizer-unscale-and-check-inf').start()
+            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
+            timers('optimizer-unscale-and-check-inf').stop()
+
+            # We are done with scaling gradients
+            # so we can update the loss scale.
+            self.grad_scaler.update(found_inf_flag)
+
+            # If we found inf/nan, skip the update.
+            if found_inf_flag:
+                return False, None, None

         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()

@@ -329,7 +403,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers('optimizer-clip-main-grad').stop()

         # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None

         # Step the optimizer.
         self.optimizer.step()

@@ -346,8 +421,9 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def state_dict(self):
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
-        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
         return state_dict

@@ -365,15 +441,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             print_rank_0('***WARNING*** found an old checkpoint, will not '
                          'load grad scaler ...')
         else:
-            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** found the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')

         # Copy data for the main params.
-        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_fp16_params_key not in state_dict:
-            fp32_from_fp16_params_key = 'fp32_from_fp16'
+        fp32_from_float16_params_key = 'fp32_from_fp16_params'
+        if fp32_from_float16_params_key not in state_dict:
+            fp32_from_float16_params_key = 'fp32_from_fp16'
         for current_group, saved_group in zip(
-                self.fp32_from_fp16_groups,
-                state_dict[fp32_from_fp16_params_key]):
+                self.fp32_from_float16_groups,
+                state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)

@@ -381,11 +462,14 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
 class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad):
-        super(FP32Optimizer, self).__init__(optimizer)
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
+        super(FP32Optimizer, self).__init__(
+            optimizer, clip_grad,
+            log_num_zeros_in_grad, params_have_main_grad)

         self._scale = torch.cuda.FloatTensor([1.0])

@@ -405,13 +489,20 @@ class FP32Optimizer(MegatronOptimizer):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""

+        # Copy main_grads to grads.
+        if self.params_have_main_grad:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.grad = param.main_grad
+
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)

         # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None

         # Update parameters.
         self.optimizer.step()
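The reworked step() only unscales, checks for inf/nan, and updates the loss scale when a grad scaler exists; pure bf16 runs skip those steps and go straight to clipping and the base optimizer step. A condensed, runnable sketch of that control flow (mixed_precision_step is an illustrative helper, and the scaler is modeled as a plain float rather than a GradScaler object):

import torch

def mixed_precision_step(main_params, base_step, grad_scaler=None,
                         clip_grad=0.0):
    # Condensed control flow of Float16OptimizerWithFloat16Params.step():
    # unscale/check-inf runs only when a grad scaler exists (fp16, or bf16
    # with an explicit loss scale); bf16 without a scaler skips it entirely.
    if grad_scaler is not None:
        inv_scale = 1.0 / grad_scaler
        found_inf = False
        for p in main_params:
            p.grad.mul_(inv_scale)
            found_inf |= not torch.isfinite(p.grad).all().item()
        if found_inf:
            return False  # skip the update; the caller lowers the loss scale
    if clip_grad > 0.0:
        torch.nn.utils.clip_grad_norm_(main_params, clip_grad)
    base_step()           # e.g. the wrapped Adam step on the fp32 main params
    return True

# Tiny usage example on an fp32 "main" copy of a parameter.
w = torch.nn.Parameter(torch.ones(3))
w.grad = torch.tensor([2.0, 4.0, 8.0])          # pretend-scaled gradients
opt = torch.optim.SGD([w], lr=0.1)
print(mixed_precision_step([w], opt.step, grad_scaler=2.0, clip_grad=1.0))
print(w)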
megatron/training.py

@@ -37,9 +37,8 @@ from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
-from megatron.model import FP16Module
+from megatron.model import Float16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR

@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory


 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()

@@ -222,8 +222,18 @@ def get_model(model_provider_func):
             model_module.cuda(torch.cuda.current_device())

     # Fp16 conversion.
-    if args.fp16:
-        model = [FP16Module(model_module) for model_module in model]
+    if args.fp16 or args.bf16:
+        model = [Float16Module(model_module, args) for model_module in model]
+
+    # For now, the layer norm does not support input float32 and output bf16.
+    # For this, we move layernorm parameters to fp32 and cast output of the
+    # layernorm operation back to bf16.
+    if args.bf16 and args.fp32_residual_connection:
+        from megatron.model import import_layernorm
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
+        for model_ in model:
+            for module_ in model_.modules():
+                if isinstance(module_, LayerNorm):
+                    module_.float()

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()

@@ -231,8 +241,12 @@ def get_model(model_provider_func):
                                      process_group=mpu.get_data_parallel_group())
                  for model_module in model]
         return model

     if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module) for model_module in model]
+        model = [LocalDDP(model_module,
+                          args.accumulate_allreduce_grads_in_fp32,
+                          args.use_contiguous_buffers_in_ddp)
+                 for model_module in model]
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '

@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)
     unwrapped_model = unwrap_model(model,
-                                   (torchDDP, LocalDDP, FP16Module))
+                                   (torchDDP, LocalDDP, Float16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
         args.iteration = 0

     # We only support local DDP with multiple micro-batches.
-    if len(model) > 1:
-        assert args.DDP_impl == 'local'
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
         assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers

@@ -331,7 +343,11 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()

     # Set grad to zero.
-    optimizer.zero_grad()
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+        for partition in model:
+            partition.zero_grad_buffer()
+    else:
+        optimizer.zero_grad()

     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:

@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
         for model_module in model:
-            model_module.allreduce_params(reduce_after=False,
-                                          fp32_allreduce=args.fp32_allreduce)
+            model_module.allreduce_gradients()
         timers('backward-params-all-reduce').stop()

     # All-reduce word_embeddings' grad across first and last stages to ensure

@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
     elif mpu.is_pipeline_last_stage(ignore_virtual=True):
         unwrapped_model = model[-1]
     unwrapped_model = unwrap_model(
-        unwrapped_model, (torchDDP, LocalDDP, FP16Module))
+        unwrapped_model, (torchDDP, LocalDDP, Float16Module))

     if unwrapped_model.share_word_embeddings:
         word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-        torch.distributed.all_reduce(word_embeddings_weight.grad,
-                                     group=mpu.get_embedding_group())
+        if args.DDP_impl == 'local':
+            grad = word_embeddings_weight.main_grad
+        else:
+            grad = word_embeddings_weight.grad
+        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()

     # Update parameters.
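Putting the training.py changes together, one iteration now zeroes the contiguous grad buffer (instead of the optimizer grads) when local DDP with contiguous buffers is active, and all-reduces via allreduce_gradients(). A sketch of that per-iteration ordering, with forward_backward standing in for whichever Megatron schedule is used (this is an illustrative outline, not the actual train_step code):

def train_iteration(model_chunks, optimizer, args, forward_backward):
    # 1) Zero grads: with local DDP + contiguous buffers the flat main_grad
    #    buffer is zeroed instead of the optimizer's param grads.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for chunk in model_chunks:
            chunk.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    # 2) Run forward/backward; backward hooks accumulate into main_grad
    #    (in fp32 when --accumulate-allreduce-grads-in-fp32 is on).
    losses = forward_backward()

    # 3) All-reduce grads across data-parallel ranks via the new method.
    if args.DDP_impl == 'local':
        for chunk in model_chunks:
            chunk.allreduce_gradients()

    # 4) Optimizer step on the fp32 main params; may be skipped on inf/nan.
    return optimizer.step(), losses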
megatron/utils.py

@@ -48,13 +48,20 @@ def unwrap_model(model, module_instances=(torchDDP)):
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters."""
+    args = get_args()
     if not isinstance(model, list):
         model = [model]
     # Remove duplicate params.
     params_data = []
-    for param in model.parameters():
-        is_not_shared = param_is_not_shared(param)
-        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        if is_not_shared and is_not_tp_duplicate:
-            params_data.append(param.data)
+    for model_ in model:
+        for param in model_.parameters():
+            is_not_shared = param_is_not_shared(param)
+            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            if is_not_shared and is_not_tp_duplicate:
+                if args.bf16:
+                    params_data.append(param.data.float())
+                else:
+                    params_data.append(param.data)
     # Calculate norm
     dummy_overflow_buf = torch.cuda.IntTensor([0])
     norm, _ = multi_tensor_applier(
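The cast to float before the parameter-norm calculation is presumably needed because the fused amp_C l2-norm kernel does not take bf16 inputs, and it also avoids accumulating the norm in low precision. A CPU-only sketch of the equivalent computation (the fused kernel is replaced here by plain torch.norm):

import torch

# Promote bf16 params to fp32 before taking the global L2 norm, mirroring
# the param.data.float() branch added above.
params = [torch.randn(100, dtype=torch.bfloat16) for _ in range(3)]
norm = torch.norm(torch.stack([p.float().norm() for p in params]))
print(norm)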