OpenDAS / Megatron-LM · Commits · b4bc51b1
"csrc/vscode:/vscode.git/clone" did not exist on "9ccee9c051cfabcdf2919fa2c1f69c11a72bf23d"
Commit b4bc51b1, authored Mar 08, 2021 by Mohammad Shoeybi; committed by Jared Casper, Mar 08, 2021
Bfloat with fp32 grad acc
Parent: 87b8b9dc
Showing 10 changed files with 504 additions and 211 deletions.
megatron/arguments.py            +25   −4
megatron/model/__init__.py        +5   −3
megatron/model/bert_model.py      +1   −1
megatron/model/distributed.py   +178  −72
megatron/model/module.py         +26  −12
megatron/model/transformer.py    +15   −2
megatron/optimizer/__init__.py   +35  −14
megatron/optimizer/optimizer.py +174  −83
megatron/training.py             +33  −15
megatron/utils.py                +12   −5
megatron/arguments.py

@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
+        assert not args.bf16
         args.params_dtype = torch.half
+    if args.bf16:
+        assert not args.fp16
+        args.params_dtype = torch.bfloat16
+        # No fusion is support for bfloat for now
+        assert not args.masked_softmax_fusion
+        assert not args.bias_gelu_fusion
+        assert not args.bias_dropout_fusion
     if args.rank == 0:
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)
 
+    # If we do accumulation and all-reduces in fp32, we need to have
+    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    if args.accumulate_allreduce_grads_in_fp32:
+        assert args.DDP_impl == 'local'
+        args.use_contiguous_buffers_in_ddp = True
+
     if args.dataloader_type is None:
         args.dataloader_type = 'single'

@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
     if args.fp32_residual_connection:
-        assert args.fp16, \
-            'residual connection in fp32 only supported when using fp16.'
+        assert args.fp16 or args.bf16, \
+            'residual connection in fp32 only supported when using fp16 or bf16.'
 
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \

@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run model in bfloat16 mode.')
     group.add_argument('--loss-scale', type=float, default=None,
                        help='Static loss scaling, positive power of 2 '
                        'values can improve fp16 convergence. If None, dynamic'

@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32. '
                        'This flag is ignored unless '
                        '--no-query-key-layer-scaling is specified.')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='All-reduce in fp32')
+    group.add_argument('--accumulate-allreduce-grads-in-fp32',
+                       action='store_true',
+                       help='Gradient accumulation and all-reduce in fp32.')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')

@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
+                       help='If set, use contiguous buffer in DDP. Note that '
+                       'this option only works woth local DDP.')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                        action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
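For context, a minimal standalone sketch of the flag interactions these hunks introduce. This is not the Megatron parser itself; the argument subset and the hard-coded parse_args values are illustrative only. It shows that --fp16 and --bf16 are mutually exclusive, that --bf16 selects torch.bfloat16 for the parameters, and that --accumulate-allreduce-grads-in-fp32 forces local DDP with contiguous buffers.

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--bf16', action='store_true')
parser.add_argument('--accumulate-allreduce-grads-in-fp32', action='store_true')
parser.add_argument('--DDP-impl', default='local', choices=['local', 'torch'])
# Illustrative invocation: bf16 with fp32 gradient accumulation/all-reduce.
args = parser.parse_args(['--bf16', '--accumulate-allreduce-grads-in-fp32'])

# Parameters dtype: fp16 and bf16 are mutually exclusive.
args.params_dtype = torch.float
if args.fp16:
    assert not args.bf16
    args.params_dtype = torch.half
if args.bf16:
    assert not args.fp16
    args.params_dtype = torch.bfloat16

# fp32 accumulation/all-reduce implies local DDP with contiguous buffers.
args.use_contiguous_buffers_in_ddp = False
if args.accumulate_allreduce_grads_in_fp32:
    assert args.DDP_impl == 'local'
    args.use_contiguous_buffers_in_ddp = True

print(args.params_dtype)   # torch.bfloat16

The commit wires the same checks into parse_args above, so an inconsistent flag combination fails at startup rather than mid-training.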
megatron/model/__init__.py

@@ -16,11 +16,13 @@
 _LAYER_NORM = None
 
-def import_layernorm(fp32_residual_connection):
+def import_layernorm(fp32_residual_connection, bf16):
     global _LAYER_NORM
     if not _LAYER_NORM:
-        if fp32_residual_connection:
+        if bf16:
+            from torch.nn import LayerNorm
+        elif fp32_residual_connection:
             from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
         else:
             from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
                         GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
-from .module import FP16Module
+from .module import Float16Module
megatron/model/bert_model.py

@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output
 
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
megatron/model/distributed.py

@@ -13,100 +13,206 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from abc import ABC
+from abc import abstractmethod
+
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import torch.distributed as dist
-from torch.nn.modules import Module
-from torch.autograd import Variable
 
-from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 
 
-class DistributedDataParallel(MegatronModule):
-
-    def __init__(self, module):
-        super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
-
-        self.module = module
-        self.data_parallel_group = mpu.get_data_parallel_group()
-
-        def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                buckets = {}
-                for name, param in self.module.named_parameters():
-                    if param.requires_grad and param.grad is not None:
-                        tp = (param.data.type())
-                        if tp not in buckets:
-                            buckets[tp] = []
-                        buckets[tp].append(param)
-                if self.warn_on_half:
-                    if torch.cuda.HalfTensor in buckets:
-                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
-                              " It is recommended to use the NCCL backend in this case.")
-                        self.warn_on_half = False
-                for tp in buckets:
-                    bucket = buckets[tp]
-                    grads = [param.grad.data for param in bucket]
-                    coalesced = _flatten_dense_tensors(grads)
-                    if fp32_allreduce:
-                        coalesced = coalesced.float()
-                    if not no_scale and not reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    dist.all_reduce(coalesced, group=self.data_parallel_group)
-                    torch.cuda.synchronize()
-                    if not no_scale and reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                        buf.copy_(synced)
-        self.hook_handles = []
-        self.hooks = []
-        for param in list(self.module.parameters()):
-            def allreduce_hook(*unused):
-                Variable._execution_engine.queue_callback(allreduce_params)
-            # handle = param.register_hook(allreduce_hook)
-            # self.hooks.append(allreduce_hook)
-            # self.hook_handles.append(handle)
-        self.allreduce_params = allreduce_params
-
-    def forward(self, *inputs, **kwargs):
-        self.needs_reduction = True
-        return self.module(*inputs, **kwargs)
-
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        #[h.remove() for h in self.hook_handles]
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-        # for handle, hook in zip(self.hook_handles, self.hooks):
-        #     d = handle.hooks_dict_ref()
-        #     d[handle.id] = hook
-        return sd
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
-
-    def load_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
-
-    '''
-    def _sync_buffers(self):
-        buffers = list(self.module._all_buffers())
-        if len(buffers) > 0:
-            # cross-node buffer sync
-            flat_buffers = _flatten_dense_tensors(buffers)
-            dist.broadcast(flat_buffers, 0)
-            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
-                buf.copy_(synced)
-
-    def train(self, mode=True):
-        # Clear NCCL communicator and CUDA event cache of the default group ID,
-        # These cache will be recreated at the later call. This is currently a
-        # work-around for a potential NCCL deadlock.
-        if dist._backend == dist.dist_backend.NCCL:
-            dist._clear_group_cache()
-        super(DistributedDataParallel, self).train(mode)
-        self.module.train(mode)
-    '''
+class MemoryBuffer:
+
+    def __init__(self, numel, dtype):
+        self.numel = numel
+        self.dtype = dtype
+        self.data = torch.zeros(self.numel,
+                                dtype=self.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+
+    def zero(self):
+        """Reset the buffer to zero."""
+        self.data.zero_()
+
+    def get(self, shape, start_index):
+        """Return a tensor with the input `shape` as a view into the
+        1-D data starting at `start_index`."""
+        end_index = start_index + shape.numel()
+        assert end_index <= self.numel, \
+            'requested tensor is out of the buffer range.'
+        buffer_tensor = self.data[start_index:end_index]
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+
+class DistributedDataParallelBase(MegatronModule, ABC):
+    """Abstract class for DDP."""
+
+    def __init__(self, module):
+        super(DistributedDataParallelBase, self).__init__()
+        # Keep a pointer to the model.
+        self.module = module
+
+    @abstractmethod
+    def allreduce_gradients(self):
+        pass
+
+    def forward(self, *inputs, **kwargs):
+        return self.module(*inputs, **kwargs)
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.module.state_dict(destination, prefix, keep_vars)
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(destination, prefix,
+                                                          keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        self.module.load_state_dict(state_dict, strict=strict)
+
+
+class DistributedDataParallel(DistributedDataParallelBase):
+    """DDP with contiguous buffers options to storre and accumulate gradients.
+    This class:
+        - has the potential to reduce memory fragmentation.
+        - provides the option to do the gradient accumulation
+          in a type other than the params type (for example fp32)
+
+    Arguments:
+        module: input model.
+        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
+            and the gradient all-reduce all in in float32. If this option is
+            true, we require `use_contiguous_buffers` to be true too.
+        use_contiguous_buffers: if true, use a contiguous buffer to store the
+            gradients.
+    """
+
+    def __init__(self, module,
+                 accumulate_allreduce_grads_in_fp32,
+                 use_contiguous_buffers):
+
+        super(DistributedDataParallel, self).__init__(module)
+
+        self.accumulate_allreduce_grads_in_fp32 \
+            = accumulate_allreduce_grads_in_fp32
+        self.use_contiguous_buffers = use_contiguous_buffers
+        # If we are using fp32-accumulate-allreduce explicitly
+        # this means we need main grads in a continous buffer.
+        if self.accumulate_allreduce_grads_in_fp32:
+            assert self.use_contiguous_buffers
+
+        # ===================================
+        # Rest of this part applies only to
+        # the case we use continuous buffers.
+        # ===================================
+        self._grad_buffers = None
+        if self.use_contiguous_buffers:
+            self._grad_buffers = {}
+
+            # Simple function to define buffer type.
+            def _get_buffer_type(param):
+                return torch.float if \
+                    self.accumulate_allreduce_grads_in_fp32 else param.dtype
+
+            # First calculate total number of elements per type.
+            type_num_elements = {}
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+                                               + param.data.nelement()
+
+            # Allocate the buffer.
+            for dtype, num_elements in type_num_elements.items():
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+            # Assume the back prop order is reverse the params order,
+            # store the start index for the gradients.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] -= param.data.nelement()
+                    param.main_grad = self._grad_buffers[dtype].get(
+                        param.data.shape, type_num_elements[dtype])
+
+            # Backward hook.
+            # Accumalation function for the gradients. We need
+            # to store them so they don't go out of scope.
+            self.grad_accs = []
+            # Loop over all the parameters in the model.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    # Expand so we get access to grad_fn.
+                    param_tmp = param.expand_as(param)
+                    # Get the gradient accumulator functtion.
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    grad_acc.register_hook(self._make_param_hook(param))
+                    self.grad_accs.append(grad_acc)
+
+    def _make_param_hook(self, param):
+        """Create the all-reduce hook for backprop."""
+        # Hook used for back-prop.
+        def param_hook(*unused):
+            # Add the gradient to the buffer.
+            if param.grad.data is not None:
+                param.main_grad.add_(param.grad.data)
+                # Now we can deallocate grad memory.
+                param.grad = None
+        return param_hook
+
+    def zero_grad_buffer(self):
+        """Set the grad buffer data to zero. Needs to be called at the
+        begining of each iteration."""
+        assert self._grad_buffers is not None, 'buffers are not initialized.'
+        for _, buffer_ in self._grad_buffers.items():
+            buffer_.zero()
+
+    def allreduce_gradients(self):
+        """Reduce gradients across data parallel ranks."""
+        # If we have buffers, simply reduce the data in the buffer.
+        if self._grad_buffers is not None:
+            for _, buffer_ in self._grad_buffers.items():
+                buffer_.data /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    buffer_.data, group=mpu.get_data_parallel_group())
+        else:
+            # Otherwise, bucketize and all-reduce
+            buckets = {}
+            # Pack the buckets.
+            for param in self.module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = param.data.type()
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+                    param.main_grad = param.grad
+
+            # For each bucket, all-reduce and copy all-reduced grads.
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                coalesced /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    coalesced, group=mpu.get_data_parallel_group())
+                for buf, synced in zip(grads, _unflatten_dense_tensors(
+                        coalesced, grads)):
+                    buf.copy_(synced)
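To make the new class easier to follow, here is a condensed, standalone sketch of the contiguous-buffer idea (a toy CPU torch.nn.Linear stands in for the model; this is not Megatron's wrapper and skips the backward hooks and the distributed collective): one fp32 buffer holds every gradient, each parameter gets a main_grad view into it, and the low-precision grads are folded into that view so a single all-reduce could later run over the whole buffer in fp32.

import torch

model = torch.nn.Linear(4, 2).bfloat16()

# Allocate one fp32 buffer covering all parameters.
numel = sum(p.numel() for p in model.parameters())
buffer = torch.zeros(numel, dtype=torch.float)

# Hand out views, filling the buffer from the end (mirroring the assumption
# above that backprop visits parameters in reverse order).
offset = numel
for param in model.parameters():
    offset -= param.numel()
    param.main_grad = buffer[offset:offset + param.numel()].view_as(param)

# After a backward pass, accumulate the bf16 grads into the fp32 buffer and
# free the bf16 grad memory, as the param hook above does.
out = model(torch.randn(3, 4).bfloat16()).sum()
out.backward()
for param in model.parameters():
    param.main_grad.add_(param.grad.float())
    param.grad = None

# A single all-reduce over `buffer` would now average grads across ranks.
print(buffer.shape, buffer.dtype)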
megatron/model/module.py

@@ -25,6 +25,7 @@ from megatron import mpu
 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)

@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
                 "this needs to be handled manually. If you are training "
                 "something is definitely wrong.")
 
+
 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val`
     #is a nested tuple/list structure."""

@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
     return rtn
 
 
-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
+def fp32_to_float16(val, float16_convertor):
+    """Convert fp32 `val` to fp16/bf16"""
     def half_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
         if isinstance(val_typecheck, _FLOAT_TYPES):
-            val = val.half()
+            val = float16_convertor(val)
         return val
     return conversion_helper(val, half_conversion)
 
 
-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
+def float16_to_fp32(val):
+    """Convert fp16/bf16 `val` to fp32"""
     def float_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
-        if isinstance(val_typecheck, _HALF_TYPES):
+        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
             val = val.float()
         return val
     return conversion_helper(val, float_conversion)
 
 
-class FP16Module(MegatronModule):
+class Float16Module(MegatronModule):
 
-    def __init__(self, module):
-        super(FP16Module, self).__init__()
-        self.add_module('module', module.half())
+    def __init__(self, module, args):
+        super(Float16Module, self).__init__()
+
+        if args.fp16:
+            self.add_module('module', module.half())
+            def float16_convertor(val):
+                return val.half()
+        elif args.bf16:
+            self.add_module('module', module.bfloat16())
+            def float16_convertor(val):
+                return val.bfloat16()
+        else:
+            raise Exception('should not be here')
+
+        self.float16_convertor = float16_convertor
 
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
-            inputs = fp32_to_fp16(inputs)
+            inputs = fp32_to_float16(inputs, self.float16_convertor)
         outputs = self.module(*inputs, **kwargs)
         if mpu.is_pipeline_last_stage():
-            outputs = fp16_to_fp32(outputs)
+            outputs = float16_to_fp32(outputs)
         return outputs
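A small sketch of the Float16Module flow above, using a toy wrapper. The class name ToyFloat16Wrapper, the plain CPU tensors, and the lambda standing in for the nested float16_convertor are all illustrative choices, not the Megatron code: fp32 inputs are converted on the way into the first pipeline stage, and half/bf16 outputs are converted back to fp32 on the way out.

import torch

class ToyFloat16Wrapper(torch.nn.Module):
    def __init__(self, module, bf16=True):
        super().__init__()
        if bf16:
            self.module = module.bfloat16()
            self.convert = lambda t: t.bfloat16()
        else:
            self.module = module.half()
            self.convert = lambda t: t.half()

    def forward(self, x):
        x = self.convert(x)     # fp32 -> fp16/bf16 at the pipeline entry
        out = self.module(x)
        return out.float()      # fp16/bf16 -> fp32 at the pipeline exit

wrapper = ToyFloat16Wrapper(torch.nn.Linear(8, 8), bf16=True)
y = wrapper(torch.randn(2, 8))  # fp32 in
print(y.dtype)                  # torch.float32 out

Keeping the convertor as a stored callable is what lets the same wrapper and the same fp32_to_float16 helper serve both the fp16 and bf16 paths.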
megatron/model/transformer.py

@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm
+        self.bf16 = args.bf16
+        self.fp32_residual_connection = args.fp32_residual_connection
 
         # Layernorm on the input data.
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
+        if self.bf16 and self.fp32_residual_connection:
+            layernorm_output = layernorm_output.bfloat16()
 
         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(layernorm_output,

@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
+        if self.bf16 and self.fp32_residual_connection:
+            layernorm_output = layernorm_output.bfloat16()
 
         if self.layer_type == LayerType.decoder:
             attention_output, attention_bias = \

@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):
             # Layer norm post the decoder attention
             layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+            if self.bf16 and self.fp32_residual_connection:
+                layernorm_output = layernorm_output.bfloat16()
 
         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)

@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()
 
+        self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
 
         # Store activation checkpoiting flag.

@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(args.fp32_residual_connection)
+            LayerNorm = import_layernorm(self.fp32_residual_connection,
+                                         self.bf16)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)

@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
+            if self.bf16 and self.fp32_residual_connection:
+                output = output.bfloat16()
         else:
             output = hidden_states
         if get_key_value:
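The repeated two-line additions above all implement the same cast pattern. A minimal standalone sketch (the shapes and the [s, b, h] layout are illustrative): with --bf16 and --fp32-residual-connection the layer norm runs with fp32 parameters on the fp32 residual stream, so its output has to be cast back to bfloat16 before the attention/MLP blocks consume it.

import torch

bf16 = True
fp32_residual_connection = True

hidden_states = torch.randn(16, 2, 32)             # fp32 residual stream, [s, b, h]
input_layernorm = torch.nn.LayerNorm(32)           # fp32 parameters (plain torch.nn.LayerNorm)

layernorm_output = input_layernorm(hidden_states)  # fp32 in, fp32 out
if bf16 and fp32_residual_connection:
    layernorm_output = layernorm_output.bfloat16() # back to bf16 for attention/MLP

print(layernorm_output.dtype)                      # torch.bfloat16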
megatron/optimizer/__init__.py

@@ -20,7 +20,7 @@ from megatron import get_args
 from megatron.model import import_layernorm
 
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
+from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
 
 
 def _get_params_for_weight_decay_optimization(modules):

@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
     Layernorms and baises will have no weight decay but the rest will.
     """
     args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)
+    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
 
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}

@@ -67,24 +67,45 @@ def get_megatron_optimizer(model):
                         momentum=args.sgd_momentum)
     else:
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))
 
-    if args.fp16:
+    # Determine whether the params have main-grad field.
+    params_have_main_grad = False
+    if args.DDP_impl == 'local':
+        params_have_main_grad = True
+
+    if args.fp16 or args.bf16:
+
+        # Grad scaler:
+        #    if loss-scale is provided, instantiate the constant scaler.
+        #    if we are using fp16 and loss-scale is not present, use a
+        #       dynamic scaler.
+        #    otherwise we are running in bf16 with no loss-scale so
+        #       leave it as None.
+        grad_scaler = None
         # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)
         # Dynamic loss scale.
         else:
-            grad_scaler = DynamicGradScaler(
-                initial_scale=args.initial_loss_scale,
-                min_scale=args.min_loss_scale,
-                growth_factor=2.0,
-                backoff_factor=0.5,
-                growth_interval=args.loss_scale_window,
-                hysteresis=args.hysteresis)
+            if args.fp16:
+                grad_scaler = DynamicGradScaler(
+                    initial_scale=args.initial_loss_scale,
+                    min_scale=args.min_loss_scale,
+                    growth_factor=2.0,
+                    backoff_factor=0.5,
+                    growth_interval=args.loss_scale_window,
+                    hysteresis=args.hysteresis)
+
         # Megatron optimizer.
-        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
-                                           args.clip_grad,
-                                           args.log_num_zeros_in_grad)
+        return Float16OptimizerWithFloat16Params(optimizer,
+                                                 args.clip_grad,
+                                                 args.log_num_zeros_in_grad,
+                                                 params_have_main_grad,
+                                                 args.bf16,
+                                                 grad_scaler)
 
     # FP32.
     return FP32Optimizer(optimizer, args.clip_grad,
-                         args.log_num_zeros_in_grad)
+                         args.log_num_zeros_in_grad,
+                         params_have_main_grad)
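The grad-scaler comment above encodes a small decision table. A standalone sketch of that logic follows; the tuple return values stand in for ConstantGradScaler/DynamicGradScaler instances, and 2**32 is only an illustrative initial scale, not the real default.

def pick_grad_scaler(fp16, bf16, loss_scale):
    if not (fp16 or bf16):
        return None                      # pure fp32 path, no scaler at all
    if loss_scale:
        return ('constant', loss_scale)  # explicit static scale, fp16 or bf16
    if fp16:
        return ('dynamic', 2**32)        # fp16 without a scale: dynamic scaler
    return None                          # bf16 without a scale: no scaling needed

print(pick_grad_scaler(fp16=False, bf16=True, loss_scale=None))   # None
print(pick_grad_scaler(fp16=True, bf16=False, loss_scale=None))   # ('dynamic', 4294967296)

The bf16 branch can skip loss scaling because bfloat16 keeps the fp32 exponent range, so gradients do not underflow the way fp16 gradients can.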
megatron/optimizer/optimizer.py

This diff is collapsed and not shown here (+174 −83 per the summary above).
megatron/training.py

@@ -37,9 +37,8 @@ from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
-from megatron.model import FP16Module
+from megatron.model import Float16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR

@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory
 
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()

@@ -222,8 +222,18 @@ def get_model(model_provider_func):
         model_module.cuda(torch.cuda.current_device())
 
     # Fp16 conversion.
-    if args.fp16:
-        model = [FP16Module(model_module) for model_module in model]
+    if args.fp16 or args.bf16:
+        model = [Float16Module(model_module, args) for model_module in model]
+
+    # For now, the layer norm does not support input float32 and outut bf16.
+    # For this, we move layernorm parameters to fp32 and cast output of the
+    # layernorm operation back to bf16.
+    if args.bf16 and args.fp32_residual_connection:
+        from megatron.model import import_layernorm
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
+        for model_ in model:
+            for module_ in model_.modules():
+                if isinstance(module_, LayerNorm):
+                    module_.float()
 
     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()

@@ -231,8 +241,12 @@ def get_model(model_provider_func):
                  process_group=mpu.get_data_parallel_group())
                  for model_module in model]
         return model
 
     if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module) for model_module in model]
+        model = [LocalDDP(model_module,
+                          args.accumulate_allreduce_grads_in_fp32,
+                          args.use_contiguous_buffers_in_ddp)
+                 for model_module in model]
         return model
 
     raise NotImplementedError('Unknown DDP implementation specified: {}. '

@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)
     unwrapped_model = unwrap_model(model,
-                                   (torchDDP, LocalDDP, FP16Module))
+                                   (torchDDP, LocalDDP, Float16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
         args.iteration = 0
 
     # We only support local DDP with multiple micro-batches.
-    if len(model) > 1:
-        assert args.DDP_impl == 'local'
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
         assert args.DDP_impl == 'local'
 
     # get model without FP16 and/or TorchDDP wrappers

@@ -331,7 +343,11 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()
 
     # Set grad to zero.
-    optimizer.zero_grad()
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+        for partition in model:
+            partition.zero_grad_buffer()
+    else:
+        optimizer.zero_grad()
 
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:

@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
         for model_module in model:
-            model_module.allreduce_params(reduce_after=False,
-                                          fp32_allreduce=args.fp32_allreduce)
+            model_module.allreduce_gradients()
         timers('backward-params-all-reduce').stop()
 
     # All-reduce word_embeddings' grad across first and last stages to ensure

@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
     elif mpu.is_pipeline_last_stage(ignore_virtual=True):
         unwrapped_model = model[-1]
         unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, FP16Module))
+            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            torch.distributed.all_reduce(word_embeddings_weight.grad,
-                                         group=mpu.get_embedding_group())
+            if args.DDP_impl == 'local':
+                grad = word_embeddings_weight.main_grad
+            else:
+                grad = word_embeddings_weight.grad
+            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()
 
     # Update parameters.
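A toy sketch of the LayerNorm handling added to get_model above (torch.nn.Sequential and the layer sizes are illustrative, and plain torch.nn.LayerNorm plays the role that import_layernorm selects for bf16): the model is cast to bf16 first, then every LayerNorm is moved back to fp32 so it can consume the fp32 residual stream; its output is cast back to bf16 inside the transformer layer, as in transformer.py above.

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(32, 32),
    torch.nn.LayerNorm(32),
    torch.nn.Linear(32, 32),
).bfloat16()                              # everything in bf16 first

for module_ in model.modules():
    if isinstance(module_, torch.nn.LayerNorm):
        module_.float()                   # layer norm parameters back to fp32

print(model[0].weight.dtype, model[1].weight.dtype)   # torch.bfloat16 torch.float32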
megatron/utils.py

@@ -48,13 +48,20 @@ def unwrap_model(model, module_instances=(torchDDP)):
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
+    args = get_args()
+    if not isinstance(model, list):
+        model = [model]
     # Remove duplicate params.
     params_data = []
-    for param in model.parameters():
-        is_not_shared = param_is_not_shared(param)
-        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        if is_not_shared and is_not_tp_duplicate:
-            params_data.append(param.data)
+    for model_ in model:
+        for param in model_.parameters():
+            is_not_shared = param_is_not_shared(param)
+            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            if is_not_shared and is_not_tp_duplicate:
+                if args.bf16:
+                    params_data.append(param.data.float())
+                else:
+                    params_data.append(param.data)
     # Calculate norm
     dummy_overflow_buf = torch.cuda.IntTensor([0])
     norm, _ = multi_tensor_applier(
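A simplified sketch of the norm change above, using plain torch in place of apex's multi_tensor_applier (the single-Linear model is illustrative): bf16 parameter data is upcast to fp32 before the l2 norm is computed, so the reduction itself never runs in bf16.

import torch

model = torch.nn.Linear(64, 64).bfloat16()
bf16 = True

params_data = []
for param in model.parameters():
    # Upcast bf16 parameter data to fp32 before taking norms.
    params_data.append(param.data.float() if bf16 else param.data)

norm = torch.norm(torch.stack([p.norm(2) for p in params_data]), 2)
print(norm.dtype)   # torch.float32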