OpenDAS / Megatron-LM / Commits / 1a2cb60c

Commit 1a2cb60c, authored Mar 08, 2021 by Jared Casper

Merge branch 'bfloat_with_fp32_grad_acc' into 'main'

Bfloat with fp32 grad acc

See merge request ADLR/megatron-lm!247

Parents: 87b8b9dc, b4bc51b1

Showing 10 changed files with 504 additions and 211 deletions (+504 −211)
megatron/arguments.py             +25   −4
megatron/model/__init__.py         +5   −3
megatron/model/bert_model.py       +1   −1
megatron/model/distributed.py    +178  −72
megatron/model/module.py          +26  −12
megatron/model/transformer.py     +15   −2
megatron/optimizer/__init__.py    +35  −14
megatron/optimizer/optimizer.py  +174  −83
megatron/training.py              +33  −15
megatron/utils.py                 +12   −5
megatron/arguments.py

@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
+        assert not args.bf16
         args.params_dtype = torch.half
+    if args.bf16:
+        assert not args.fp16
+        args.params_dtype = torch.bfloat16
+        # No fusion is support for bfloat for now
+        assert not args.masked_softmax_fusion
+        assert not args.bias_gelu_fusion
+        assert not args.bias_dropout_fusion
     if args.rank == 0:
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

+    # If we do accumulation and all-reduces in fp32, we need to have
+    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    if args.accumulate_allreduce_grads_in_fp32:
+        assert args.DDP_impl == 'local'
+        args.use_contiguous_buffers_in_ddp = True
+
     if args.dataloader_type is None:
         args.dataloader_type = 'single'

@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
     if args.fp32_residual_connection:
-        assert args.fp16, \
-            'residual connection in fp32 only supported when using fp16.'
+        assert args.fp16 or args.bf16, \
+            'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \

@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run model in bfloat16 mode.')
     group.add_argument('--loss-scale', type=float, default=None,
                        help='Static loss scaling, positive power of 2 '
                        'values can improve fp16 convergence. If None, dynamic'

@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32. '
                        'This flag is ignored unless '
                        '--no-query-key-layer-scaling is specified.')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='All-reduce in fp32')
+    group.add_argument('--accumulate-allreduce-grads-in-fp32',
+                       action='store_true',
+                       help='Gradient accumulation and all-reduce in fp32.')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')

@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
+                       help='If set, use contiguous buffer in DDP. Note that '
+                       'this option only works woth local DDP.')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                        action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
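The net effect of the argument changes above: --fp16 and --bf16 are mutually exclusive, bf16 currently disables the fused kernels, and --accumulate-allreduce-grads-in-fp32 implies local DDP with contiguous buffers. A minimal, self-contained sketch of that validation logic, using a plain SimpleNamespace instead of the real argument parser (the helper name resolve_precision is illustrative, not Megatron's API):

from types import SimpleNamespace

import torch


def resolve_precision(args):
    # Mirrors the new parse_args() checks: pick a params dtype and
    # force the DDP settings that fp32 gradient accumulation requires.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16, '--fp16 and --bf16 are mutually exclusive'
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True
    return args


args = resolve_precision(SimpleNamespace(
    fp16=False, bf16=True, DDP_impl='local',
    accumulate_allreduce_grads_in_fp32=True,
    use_contiguous_buffers_in_ddp=False))
print(args.params_dtype)                    # torch.bfloat16
print(args.use_contiguous_buffers_in_ddp)   # True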
megatron/model/__init__.py

@@ -16,11 +16,13 @@
 _LAYER_NORM = None

-def import_layernorm(fp32_residual_connection):
+def import_layernorm(fp32_residual_connection, bf16):
     global _LAYER_NORM
     if not _LAYER_NORM:
-        if fp32_residual_connection:
+        if bf16:
+            from torch.nn import LayerNorm
+        elif fp32_residual_connection:
             from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
         else:
             from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
                         GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
-from .module import FP16Module
+from .module import Float16Module
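For reference, the selection rule import_layernorm() now encodes is: plain torch.nn.LayerNorm for bf16 (no fused bf16 kernel yet), Megatron's MixedFusedLayerNorm when the residual connection is kept in fp32, and apex's FusedLayerNorm otherwise. A small hedged sketch of the same decision, returning names as strings so it runs without apex installed (pick_layernorm is a hypothetical helper, not part of the repo):

def pick_layernorm(fp32_residual_connection, bf16):
    # Same decision order as import_layernorm(), but returning a label
    # instead of importing the class, so the sketch has no apex dependency.
    if bf16:
        return 'torch.nn.LayerNorm'            # no fused bf16 kernel for now
    if fp32_residual_connection:
        return 'megatron MixedFusedLayerNorm'  # fp32 residual stream
    return 'apex FusedLayerNorm'


for flags in [(False, False), (True, False), (False, True)]:
    print(flags, '->', pick_layernorm(*flags))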
megatron/model/bert_model.py

@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
megatron/model/distributed.py

@@ -13,100 +13,206 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from abc import ABC
+from abc import abstractmethod

 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import torch.distributed as dist
-from torch.nn.modules import Module
-from torch.autograd import Variable

-from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule


-class DistributedDataParallel(MegatronModule):
-    def __init__(self, module):
-        super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
-
-        self.module = module
-        self.data_parallel_group = mpu.get_data_parallel_group()
-
-        def allreduce_params(reduce_after=True, no_scale=False,
-                             fp32_allreduce=False):
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                buckets = {}
-                for name, param in self.module.named_parameters():
-                    if param.requires_grad and param.grad is not None:
-                        tp = (param.data.type())
-                        if tp not in buckets:
-                            buckets[tp] = []
-                        buckets[tp].append(param)
-                if self.warn_on_half:
-                    if torch.cuda.HalfTensor in buckets:
-                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
-                              " It is recommended to use the NCCL backend in this case.")
-                        self.warn_on_half = False
-                for tp in buckets:
-                    bucket = buckets[tp]
-                    grads = [param.grad.data for param in bucket]
-                    coalesced = _flatten_dense_tensors(grads)
-                    if fp32_allreduce:
-                        coalesced = coalesced.float()
-                    if not no_scale and not reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    dist.all_reduce(coalesced, group=self.data_parallel_group)
-                    torch.cuda.synchronize()
-                    if not no_scale and reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                        buf.copy_(synced)
-
-        self.hook_handles = []
-        self.hooks = []
-        for param in list(self.module.parameters()):
-            def allreduce_hook(*unused):
-                Variable._execution_engine.queue_callback(allreduce_params)
-            # handle = param.register_hook(allreduce_hook)
-            # self.hooks.append(allreduce_hook)
-            # self.hook_handles.append(handle)
-        self.allreduce_params = allreduce_params
-
-    def forward(self, *inputs, **kwargs):
-        self.needs_reduction = True
-        return self.module(*inputs, **kwargs)
-
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        #[h.remove() for h in self.hook_handles]
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-        # for handle, hook in zip(self.hook_handles, self.hooks):
-        #     d = handle.hooks_dict_ref()
-        #     d[handle.id] = hook
-        return sd
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
-
-    def load_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
-
-    '''
-    def _sync_buffers(self):
-        buffers = list(self.module._all_buffers())
-        if len(buffers) > 0:
-            # cross-node buffer sync
-            flat_buffers = _flatten_dense_tensors(buffers)
-            dist.broadcast(flat_buffers, 0)
-            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
-                buf.copy_(synced)
-
-    def train(self, mode=True):
-        # Clear NCCL communicator and CUDA event cache of the default group ID,
-        # These cache will be recreated at the later call. This is currently a
-        # work-around for a potential NCCL deadlock.
-        if dist._backend == dist.dist_backend.NCCL:
-            dist._clear_group_cache()
-        super(DistributedDataParallel, self).train(mode)
-        self.module.train(mode)
-    '''
+class MemoryBuffer:
+
+    def __init__(self, numel, dtype):
+        self.numel = numel
+        self.dtype = dtype
+        self.data = torch.zeros(self.numel,
+                                dtype=self.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+
+    def zero(self):
+        """Reset the buffer to zero."""
+        self.data.zero_()
+
+    def get(self, shape, start_index):
+        """Return a tensor with the input `shape` as a view into the
+        1-D data starting at `start_index`."""
+        end_index = start_index + shape.numel()
+        assert end_index <= self.numel, \
+            'requested tensor is out of the buffer range.'
+        buffer_tensor = self.data[start_index:end_index]
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+
+class DistributedDataParallelBase(MegatronModule, ABC):
+    """Abstract class for DDP."""
+
+    def __init__(self, module):
+        super(DistributedDataParallelBase, self).__init__()
+        # Keep a pointer to the model.
+        self.module = module
+
+    @abstractmethod
+    def allreduce_gradients(self):
+        pass
+
+    def forward(self, *inputs, **kwargs):
+        return self.module(*inputs, **kwargs)
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.module.state_dict(destination, prefix, keep_vars)
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(destination, prefix,
+                                                          keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        self.module.load_state_dict(state_dict, strict=strict)
+
+
+class DistributedDataParallel(DistributedDataParallelBase):
+    """DDP with contiguous buffers options to storre and accumulate gradients.
+    This class:
+        - has the potential to reduce memory fragmentation.
+        - provides the option to do the gradient accumulation
+          in a type other than the params type (for example fp32)
+    Arguments:
+        module: input model.
+        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
+            and the gradient all-reduce all in in float32. If this option is
+            true, we require `use_contiguous_buffers` to be true too.
+        use_contiguous_buffers: if true, use a contiguous buffer to store the
+            gradients.
+    """
+
+    def __init__(self, module,
+                 accumulate_allreduce_grads_in_fp32,
+                 use_contiguous_buffers):
+
+        super(DistributedDataParallel, self).__init__(module)
+
+        self.accumulate_allreduce_grads_in_fp32 \
+            = accumulate_allreduce_grads_in_fp32
+        self.use_contiguous_buffers = use_contiguous_buffers
+        # If we are using fp32-accumulate-allreduce explicitly
+        # this means we need main grads in a continous buffer.
+        if self.accumulate_allreduce_grads_in_fp32:
+            assert self.use_contiguous_buffers
+
+        # ===================================
+        # Rest of this part applies only to
+        # the case we use continuous buffers.
+        # ===================================
+        self._grad_buffers = None
+        if self.use_contiguous_buffers:
+            self._grad_buffers = {}
+
+            # Simple function to define buffer type.
+            def _get_buffer_type(param):
+                return torch.float if \
+                    self.accumulate_allreduce_grads_in_fp32 else param.dtype
+
+            # First calculate total number of elements per type.
+            type_num_elements = {}
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+                                               + param.data.nelement()
+
+            # Allocate the buffer.
+            for dtype, num_elements in type_num_elements.items():
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+            # Assume the back prop order is reverse the params order,
+            # store the start index for the gradients.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] -= param.data.nelement()
+                    param.main_grad = self._grad_buffers[dtype].get(
+                        param.data.shape, type_num_elements[dtype])
+
+            # Backward hook.
+            # Accumalation function for the gradients. We need
+            # to store them so they don't go out of scope.
+            self.grad_accs = []
+            # Loop over all the parameters in the model.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    # Expand so we get access to grad_fn.
+                    param_tmp = param.expand_as(param)
+                    # Get the gradient accumulator functtion.
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    grad_acc.register_hook(self._make_param_hook(param))
+                    self.grad_accs.append(grad_acc)
+
+    def _make_param_hook(self, param):
+        """Create the all-reduce hook for backprop."""
+        # Hook used for back-prop.
+        def param_hook(*unused):
+            # Add the gradient to the buffer.
+            if param.grad.data is not None:
+                param.main_grad.add_(param.grad.data)
+                # Now we can deallocate grad memory.
+                param.grad = None
+        return param_hook
+
+    def zero_grad_buffer(self):
+        """Set the grad buffer data to zero. Needs to be called at the
+        begining of each iteration."""
+        assert self._grad_buffers is not None, 'buffers are not initialized.'
+        for _, buffer_ in self._grad_buffers.items():
+            buffer_.zero()
+
+    def allreduce_gradients(self):
+        """Reduce gradients across data parallel ranks."""
+        # If we have buffers, simply reduce the data in the buffer.
+        if self._grad_buffers is not None:
+            for _, buffer_ in self._grad_buffers.items():
+                buffer_.data /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    buffer_.data, group=mpu.get_data_parallel_group())
+        else:
+            # Otherwise, bucketize and all-reduce
+            buckets = {}
+            # Pack the buckets.
+            for param in self.module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = param.data.type()
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+                    param.main_grad = param.grad
+
+            # For each bucket, all-reduce and copy all-reduced grads.
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                coalesced /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    coalesced, group=mpu.get_data_parallel_group())
+                for buf, synced in zip(grads, _unflatten_dense_tensors(
+                        coalesced, grads)):
+                    buf.copy_(synced)
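The heart of the new DistributedDataParallel is the accumulate-into-main_grad trick: every parameter gets a view into one contiguous (optionally fp32) buffer, and a hook on its gradient accumulator adds param.grad into that view and frees param.grad immediately. A standalone, single-process sketch of just that mechanism (CPU tensors, no all-reduce; the names here are for illustration only):

import torch

# One small model; all gradients will land in a single fp32 buffer.
model = torch.nn.Linear(4, 4)
params = [p for p in model.parameters() if p.requires_grad]

numel = sum(p.data.nelement() for p in params)
buffer = torch.zeros(numel, dtype=torch.float)   # fp32 accumulation buffer

grad_accs = []
offset = numel
for p in params:
    # Carve out a view of the buffer shaped like the parameter.
    offset -= p.data.nelement()
    p.main_grad = buffer[offset:offset + p.data.nelement()].view(p.shape)
    # grad_fn of p.expand_as(p) exposes the AccumulateGrad node for p.
    p_tmp = p.expand_as(p)
    grad_acc = p_tmp.grad_fn.next_functions[0][0]

    def make_hook(p):
        def hook(*unused):
            # Fires after p.grad is populated: accumulate in fp32, then free.
            if p.grad is not None:
                p.main_grad.add_(p.grad.data.float())
                p.grad = None
        return hook

    grad_acc.register_hook(make_hook(p))
    grad_accs.append(grad_acc)   # keep the nodes alive, as the real code does

for _ in range(2):               # two micro-batches accumulate into the buffer
    model(torch.randn(8, 4)).sum().backward()

print(params[0].grad)              # None: freed by the hook
print(params[0].main_grad.norm())  # accumulated fp32 gradient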
megatron/model/module.py

@@ -25,6 +25,7 @@ from megatron import mpu
 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)

@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
                 "this needs to be handled manually. If you are training "
                 "something is definitely wrong.")

 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val`
     #is a nested tuple/list structure."""

@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
     return rtn


-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
+def fp32_to_float16(val, float16_convertor):
+    """Convert fp32 `val` to fp16/bf16"""
     def half_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
         if isinstance(val_typecheck, _FLOAT_TYPES):
-            val = val.half()
+            val = float16_convertor(val)
         return val
     return conversion_helper(val, half_conversion)


-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
+def float16_to_fp32(val):
+    """Convert fp16/bf16 `val` to fp32"""
     def float_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
-        if isinstance(val_typecheck, _HALF_TYPES):
+        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
             val = val.float()
         return val
     return conversion_helper(val, float_conversion)


-class FP16Module(MegatronModule):
+class Float16Module(MegatronModule):

-    def __init__(self, module):
-        super(FP16Module, self).__init__()
-        self.add_module('module', module.half())
+    def __init__(self, module, args):
+        super(Float16Module, self).__init__()
+
+        if args.fp16:
+            self.add_module('module', module.half())
+            def float16_convertor(val):
+                return val.half()
+        elif args.bf16:
+            self.add_module('module', module.bfloat16())
+            def float16_convertor(val):
+                return val.bfloat16()
+        else:
+            raise Exception('should not be here')
+
+        self.float16_convertor = float16_convertor

     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
-            inputs = fp32_to_fp16(inputs)
+            inputs = fp32_to_float16(inputs, self.float16_convertor)
         outputs = self.module(*inputs, **kwargs)
         if mpu.is_pipeline_last_stage():
-            outputs = fp16_to_fp32(outputs)
+            outputs = float16_to_fp32(outputs)
         return outputs
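A condensed sketch of what Float16Module does around the wrapped model: cast fp32 tensors (including nested tuples, as conversion_helper does) to the chosen 16-bit dtype on entry to the first pipeline stage and back to fp32 on exit from the last one. Float16Wrapper and convert below are simplified stand-ins, not the Megatron classes:

import torch


def convert(val, conversion):
    # Apply `conversion` to every tensor in a (possibly nested) tuple/list,
    # mirroring what conversion_helper() does in module.py.
    if isinstance(val, (tuple, list)):
        return type(val)(convert(v, conversion) for v in val)
    return conversion(val) if torch.is_tensor(val) else val


class Float16Wrapper(torch.nn.Module):
    # Simplified stand-in for Float16Module: 16-bit params inside,
    # fp32 tensors at the pipeline boundary.
    def __init__(self, module, bf16=True):
        super().__init__()
        self.dtype = torch.bfloat16 if bf16 else torch.half
        self.module = module.to(self.dtype)

    def forward(self, *inputs, **kwargs):
        inputs = convert(inputs, lambda t: t.to(self.dtype)
                         if t.dtype == torch.float else t)
        outputs = self.module(*inputs, **kwargs)
        return convert(outputs, lambda t: t.float()
                       if t.dtype == self.dtype else t)


wrapped = Float16Wrapper(torch.nn.Linear(4, 4), bf16=True)
out = wrapped(torch.randn(2, 4))               # fp32 in ...
print(wrapped.module.weight.dtype, out.dtype)  # bf16 params, fp32 out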
megatron/model/transformer.py

@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

+        self.bf16 = args.bf16
+        self.fp32_residual_connection = args.fp32_residual_connection
+
         # Layernorm on the input data.
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
+        if self.bf16 and self.fp32_residual_connection:
+            layernorm_output = layernorm_output.bfloat16()

         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(layernorm_output,

@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
+        if self.bf16 and self.fp32_residual_connection:
+            layernorm_output = layernorm_output.bfloat16()

         if self.layer_type == LayerType.decoder:
             attention_output, attention_bias = \

@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):
             # Layer norm post the decoder attention
             layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+            if self.bf16 and self.fp32_residual_connection:
+                layernorm_output = layernorm_output.bfloat16()

         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)

@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()

+        self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection

         # Store activation checkpoiting flag.

@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(args.fp32_residual_connection)
+            LayerNorm = import_layernorm(self.fp32_residual_connection,
+                                         self.bf16)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)

@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
+            if self.bf16 and self.fp32_residual_connection:
+                output = output.bfloat16()
         else:
             output = hidden_states
         if get_key_value:
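The pattern the transformer changes implement, in isolation: with --bf16 and --fp32-residual-connection the layer norms (moved to fp32 parameters in training.py) see an fp32 residual stream, and their output is cast back to bf16 before the bf16 attention/MLP blocks. A small illustrative sketch, assuming a recent PyTorch build with CPU bf16 support:

import torch

hidden = torch.randn(16, 2, 8)                 # [s, b, h] residual kept in fp32
layernorm = torch.nn.LayerNorm(8)              # layernorm weights stay fp32
attention = torch.nn.Linear(8, 8).bfloat16()   # stands in for the bf16 block

ln_out = layernorm(hidden)                     # fp32 in, fp32 out
ln_out = ln_out.bfloat16()                     # cast back to bf16, as the new code does
attn_out = attention(ln_out)
residual = hidden + attn_out.float()           # accumulate the residual in fp32
print(ln_out.dtype, residual.dtype)            # torch.bfloat16 torch.float32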
megatron/optimizer/__init__.py

@@ -20,7 +20,7 @@ from megatron import get_args
 from megatron.model import import_layernorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
+from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


 def _get_params_for_weight_decay_optimization(modules):

@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
     Layernorms and baises will have no weight decay but the rest will.
     """
     args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)
+    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)

     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}

@@ -69,12 +69,26 @@ def get_megatron_optimizer(model):
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))

-    if args.fp16:
+    # Determine whether the params have main-grad field.
+    params_have_main_grad = False
+    if args.DDP_impl == 'local':
+        params_have_main_grad = True
+
+    if args.fp16 or args.bf16:
+
+        # Grad scaler:
+        #    if loss-scale is provided, instantiate the constant scaler.
+        #    if we are using fp16 and loss-scale is not present, use a
+        #       dynamic scaler.
+        #    otherwise we are running in bf16 with no loss-scale so
+        #       leave it as None.
+        grad_scaler = None
         # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)
         # Dynamic loss scale.
         else:
+            if args.fp16:
                 grad_scaler = DynamicGradScaler(
                     initial_scale=args.initial_loss_scale,
                     min_scale=args.min_loss_scale,

@@ -82,9 +96,16 @@ def get_megatron_optimizer(model):
                     backoff_factor=0.5,
                     growth_interval=args.loss_scale_window,
                     hysteresis=args.hysteresis)

         # Megatron optimizer.
-        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
-                                           args.clip_grad,
-                                           args.log_num_zeros_in_grad)
+        return Float16OptimizerWithFloat16Params(optimizer,
+                                                 args.clip_grad,
+                                                 args.log_num_zeros_in_grad,
+                                                 params_have_main_grad,
+                                                 args.bf16,
+                                                 grad_scaler)

     # FP32.
-    return FP32Optimizer(optimizer, args.clip_grad,
-                         args.log_num_zeros_in_grad)
+    return FP32Optimizer(optimizer, args.clip_grad,
+                         args.log_num_zeros_in_grad,
+                         params_have_main_grad)
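The grad-scaler selection above reduces to a small decision table: a constant scaler when --loss-scale is given, a dynamic scaler for fp16 without one, and no scaler at all for bf16 (its wider exponent range makes loss scaling unnecessary). A hedged sketch of that rule with a hypothetical pick_grad_scaler helper standing in for the real ConstantGradScaler/DynamicGradScaler objects:

def pick_grad_scaler(fp16, bf16, loss_scale=None):
    # Same branching as the new get_megatron_optimizer() logic.
    assert fp16 or bf16
    if loss_scale:
        return ('constant', loss_scale)
    if fp16:
        return ('dynamic', None)
    return None   # bf16 without an explicit loss scale: no scaling at all


print(pick_grad_scaler(fp16=True, bf16=False))                # ('dynamic', None)
print(pick_grad_scaler(fp16=False, bf16=True))                # None
print(pick_grad_scaler(fp16=False, bf16=True, loss_scale=4))  # ('constant', 4)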
megatron/optimizer/optimizer.py

@@ -46,24 +46,37 @@ def _zero_grad_group_helper(group, set_to_none):
 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another."""
+    """Use multi-tensor-applier to copy values from one list to another.
+    We don't have a blfoat16 implementation so for now if the overflow_buf
+    is not provided, we default back to simple loop copy to be compatible
+    with bfloat16."""
     if overflow_buf:
         overflow_buf.fill_(0)
-    else:
-        overflow_buf = torch.cuda.IntTensor([0])
-    # Scaling with factor `1.0` is equivalent to copy.
-    multi_tensor_applier(amp_C.multi_tensor_scale,
-                         overflow_buf,
-                         [this, that],
-                         1.0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             overflow_buf,
+                             [this, that],
+                             1.0)
+    else:
+        for this_, that_ in zip(this, that):
+            that_.copy_(this_)


 class MegatronOptimizer(ABC):

-    def __init__(self, optimizer):
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'
+        # Set gradient clipping and logging params.
+        self.clip_grad = clip_grad
+        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+        self.params_have_main_grad = params_have_main_grad

     def get_parameters(self):
         params = []

@@ -72,31 +85,38 @@ class MegatronOptimizer(ABC):
                 params.append(param)
         return params

     def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
         return clip_grad_norm_fp32(params, clip_grad)

     def count_zeros(self):
         params = self.get_parameters()
         return count_zeros_fp32(params)

     @abstractmethod
     def zero_grad(self, set_to_none=True):
         pass

     @abstractmethod
     def get_loss_scale(self):
         """The output should be a cuda tensor of size 1."""
         pass

     def scale_loss(self, loss):
         """Simple scaling."""
         return self.get_loss_scale() * loss

     @abstractmethod
     def step(self):
         pass

     @abstractmethod
     def reload_model_params(self):
         """Refreshes any internal state from the current model parameters.

@@ -106,14 +126,17 @@ class MegatronOptimizer(ABC):
        with main parameters, the main parameters need to also be updated."""
         pass

     @abstractmethod
     def state_dict(self):
         pass

     @abstractmethod
     def load_state_dict(self, state_dict):
         pass

     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):

@@ -124,6 +147,7 @@ class MegatronOptimizer(ABC):
     state = property(_get_state, _set_state)

     # Promote param_groups so it can be retrieved or set via
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)

@@ -137,50 +161,90 @@ class MegatronOptimizer(ABC):
-class FP16OptimizerWithFP16Params(MegatronOptimizer):
+class Float16OptimizerWithFloat16Params(MegatronOptimizer):
+    """Float16 optimizer for fp16 and bf16 data types.
+
+    Arguments:
+        optimizer: base optimizer such as Adam or SGD
+        clip_grad: clip gradeints with this global L2 norm. Note
+            that clipping is ignored if clip_grad == 0
+        log_num_zeros_in_grad: return number of zeros in the gradients.
+        params_have_main_grad: flag indicating if parameters have
+            a `main_grad` field. If this is set, we are assuming
+            that the model parameters are store in the `main_grad`
+            field instead of the typical `grad` field. This happens
+            for the DDP cases where there is a contihuous buffer
+            holding the gradients. For example for bfloat16, we want
+            to do gradient accumulation and all-reduces in float32
+            and as a result we store those gradients in the main_grad.
+            Note that main grad is not necessarily in float32.
+        bf16: if true, the model is running in bfloat16.
+        grad_scaler: used for scaling gradients. Note that this can be
+            None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constnat gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+    """

-    def __init__(self, optimizer, grad_scaler, clip_grad,
-                 log_num_zeros_in_grad):
-        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, bf16, grad_scaler):
+        super(Float16OptimizerWithFloat16Params, self).__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad)

+        self.bf16 = bf16
         self.grad_scaler = grad_scaler
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+        # None grad scaler is only supported for bf16.
+        if self.grad_scaler is None:
+            assert self.bf16, 'fp16 expects a grad scaler.'

         # Tensor used to determine if a nan/if has happend.
         # Any non-zero value indicates inf/nan.
-        self.found_inf = torch.cuda.FloatTensor([0.0])
+        # Note that we keep this for the cases that grad scaler is none.
+        # We still record nan/inf if we have a bfloat16 with a grad scaler.
+        if self.grad_scaler:
+            self.found_inf = torch.cuda.FloatTensor([0.0])

         # Dummy tensor needed for apex multi-apply tensor.
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # For bfloat, we don't have multi-tensor apply and for now
+        # we set it to none so the multi-tensor apply gets ignored.
+        if bf16:
+            self._dummy_overflow_buf = None
+        else:
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

+        # In case grad scaler is not passed, define the unity scale.
+        if self.grad_scaler is None:
+            self._scale_one = torch.cuda.FloatTensor([1.0])

         # ======================
         # main parameter stuff
         # ======================

         # Three groups of parameters:
-        #   fp16_groups: original fp16 parameters
-        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
         #   fp32_from_fp32_groups: original fp32 parameters
-        self.fp16_groups = []
-        self.fp32_from_fp16_groups = []
+        self.float16_groups = []
+        self.fp32_from_float16_groups = []
         self.fp32_from_fp32_groups = []

         # For all the groups in the original optimizer:
         for param_group in self.optimizer.param_groups:
-            fp16_params_this_group = []
+            float16_params_this_group = []
             fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
+            fp32_from_float16_params_this_group = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
-                    # fp16 params:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        fp16_params_this_group.append(param)
+                    # float16 params:
+                    if param.type() in ['torch.cuda.HalfTensor',
+                                        'torch.cuda.BFloat16Tensor']:
+                        float16_params_this_group.append(param)
                         # Create a copy
                         main_param = param.detach().clone().float()
-                        # Store grads
-                        main_param.requires_grad = True
                         # Copy tensor model parallel attributes.
                         mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                   param)

@@ -188,7 +252,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
-                        fp32_from_fp16_params_this_group.append(main_param)
+                        fp32_from_float16_params_this_group.append(main_param)
                         # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \

@@ -200,13 +264,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                         param_group['params'][i] = param

                     else:
-                        raise TypeError("Wrapped parameters must be either "
-                                        "torch.cuda.FloatTensor or "
-                                        "torch.cuda.HalfTensor. "
-                                        "Received {}".format(param.type()))
+                        raise TypeError('Wrapped parameters must be one of '
+                                        'torch.cuda.FloatTensor, '
+                                        'torch.cuda.HalfTensor, or '
+                                        'torch.cuda.BFloat16Tensor. '
+                                        'Received {}'.format(param.type()))

-            self.fp16_groups.append(fp16_params_this_group)
-            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+            self.float16_groups.append(float16_params_this_group)
+            self.fp32_from_float16_groups.append(
+                fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

         # Leverage state_dict() and load_state_dict() to

@@ -216,37 +282,40 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
-        fp16_groups & fp32_from_fp32_groups."""
-        for group in self.fp16_groups:
+        float16_groups & fp32_from_fp32_groups."""
+        for group in self.float16_groups:
             _zero_grad_group_helper(group, set_to_none)
         for group in self.fp32_from_fp32_groups:
             _zero_grad_group_helper(group, set_to_none)

     def get_loss_scale(self):
+        if self.grad_scaler is None:
+            return self._scale_one
         return self.grad_scaler.scale

     def _copy_model_grads_to_main_grads(self):
-        # This only needs to be done for the fp16 group.
-        model_grads = []
-        main_grads = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                if model_param.grad is not None:
-                    if main_param.grad is None:
-                        main_param.grad = torch.empty_like(main_param)
-                    model_grads.append(model_param.grad.data)
-                    main_grads.append(main_param.grad.data)
-        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
-                                        overflow_buf=self._dummy_overflow_buf)
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
+            for model_param, main_param in zip(model_group, main_group):
+                if self.params_have_main_grad:
+                    main_param.grad = model_param.main_grad.float()
+                else:
+                    if model_param.grad is not None:
+                        main_param.grad = model_param.grad.float()
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        if self.params_have_main_grad:
+            for model_group in self.fp32_from_fp32_groups:
+                for model_param in model_group:
+                    model_param.grad = model_param.main_grad

     def _unscale_main_grads_and_check_for_nan(self):
         main_grads = []
-        # fp32 params fromm fp16 ones.
-        for main_group in self.fp32_from_fp16_groups:
+        # fp32 params fromm float16 ones.
+        for main_group in self.fp32_from_float16_groups:
             for main_param in main_group:
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)

@@ -270,11 +339,11 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         return found_inf_flag

-    def _get_model_and_main_params_data_fp16(self):
+    def _get_model_and_main_params_data_float16(self):
         model_data = []
         main_data = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
                 main_data.append(main_param.data)

@@ -282,15 +351,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     def _copy_main_params_to_model_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)

     def _copy_model_params_to_main_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                         overflow_buf=self._dummy_overflow_buf)

@@ -298,6 +367,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     def reload_model_params(self):
         self._copy_model_params_to_main_params()

     @torch.no_grad()
     def step(self):

@@ -308,6 +378,10 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()

+        # Do unscale, check for inf, and update grad scaler only for
+        # the case that grad scaler is provided.
+        if self.grad_scaler:
             # Unscale and check for inf/nan.
             timers('optimizer-unscale-and-check-inf').start()
             found_inf_flag = self._unscale_main_grads_and_check_for_nan()

@@ -329,7 +403,8 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         timers('optimizer-clip-main-grad').stop()

         # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None

         # Step the optimizer.
         self.optimizer.step()

@@ -346,8 +421,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     def state_dict(self):
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
-        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
         return state_dict

@@ -365,15 +441,20 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             print_rank_0('***WARNING*** found an old checkpoint, will not '
                          'load grad scaler ...')
         else:
-            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** fould the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')

         # Copy data for the main params.
-        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_fp16_params_key not in state_dict:
-            fp32_from_fp16_params_key = 'fp32_from_fp16'
+        fp32_from_float16_params_key = 'fp32_from_fp16_params'
+        if fp32_from_float16_params_key not in state_dict:
+            fp32_from_float16_params_key = 'fp32_from_fp16'
         for current_group, saved_group in zip(
-                self.fp32_from_fp16_groups,
-                state_dict[fp32_from_fp16_params_key]):
+                self.fp32_from_float16_groups,
+                state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)

@@ -381,11 +462,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
 class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad):
-        super(FP32Optimizer, self).__init__(optimizer)
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
+        super(FP32Optimizer, self).__init__(
+            optimizer, clip_grad,
+            log_num_zeros_in_grad, params_have_main_grad)

         self._scale = torch.cuda.FloatTensor([1.0])

@@ -405,13 +489,20 @@ class FP32Optimizer(MegatronOptimizer):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""

+        # Copy main_grads to grads.
+        if self.params_have_main_grad:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.grad = param.main_grad
+
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)

         # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None

         # Update parameters.
         self.optimizer.step()
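The step() flow of Float16OptimizerWithFloat16Params boils down to: copy the (possibly fp32-accumulated) model grads onto fp32 main params, unscale and check for inf/nan only when a grad scaler exists, step the base optimizer in fp32, then copy the updated main params back into the 16-bit model params. A single-parameter, single-step sketch of that round trip in plain PyTorch (no Megatron classes, SGD standing in for the base optimizer):

import torch

model_param = torch.nn.Parameter(torch.randn(4).bfloat16())   # 16-bit weight
main_param = model_param.detach().clone().float()             # fp32 master copy
optimizer = torch.optim.SGD([main_param], lr=0.1)

# Pretend backward produced an fp32-accumulated gradient (the main_grad field).
model_param.main_grad = torch.randn(4)

# _copy_model_grads_to_main_grads(): hand the fp32 grad to the master copy.
main_param.grad = model_param.main_grad.float()

optimizer.step()                                               # fp32 update

# _copy_main_params_to_model_params(): push the result back to the bf16 weight.
with torch.no_grad():
    model_param.copy_(main_param)

print(model_param.dtype, main_param.dtype)   # torch.bfloat16 torch.float32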
megatron/training.py

@@ -37,9 +37,8 @@ from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
-from megatron.model import FP16Module
+from megatron.model import Float16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR

@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory

 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()

@@ -222,8 +222,18 @@ def get_model(model_provider_func):
         model_module.cuda(torch.cuda.current_device())

     # Fp16 conversion.
-    if args.fp16:
-        model = [FP16Module(model_module) for model_module in model]
+    if args.fp16 or args.bf16:
+        model = [Float16Module(model_module, args) for model_module in model]
+
+    # For now, the layer norm does not support input float32 and outut bf16.
+    # For this, we move layernorm parameters to fp32 and cast output of the
+    # layernorm operation back to bf16.
+    if args.bf16 and args.fp32_residual_connection:
+        from megatron.model import import_layernorm
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
+        for model_ in model:
+            for module_ in model_.modules():
+                if isinstance(module_, LayerNorm):
+                    module_.float()

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()

@@ -231,8 +241,12 @@ def get_model(model_provider_func):
                                       process_group=mpu.get_data_parallel_group())
                  for model_module in model]
         return model

     if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module) for model_module in model]
+        model = [LocalDDP(model_module,
+                          args.accumulate_allreduce_grads_in_fp32,
+                          args.use_contiguous_buffers_in_ddp)
+                 for model_module in model]
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '

@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)
     unwrapped_model = unwrap_model(model,
-                                   (torchDDP, LocalDDP, FP16Module))
+                                   (torchDDP, LocalDDP, Float16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
         args.iteration = 0

     # We only support local DDP with multiple micro-batches.
-    if len(model) > 1:
-        assert args.DDP_impl == 'local'
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
         assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers

@@ -331,6 +343,10 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()

     # Set grad to zero.
-    optimizer.zero_grad()
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+        for partition in model:
+            partition.zero_grad_buffer()
+    else:
+        optimizer.zero_grad()

     if mpu.get_pipeline_model_parallel_world_size() > 1:

@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
         for model_module in model:
-            model_module.allreduce_params(reduce_after=False,
-                                          fp32_allreduce=args.fp32_allreduce)
+            model_module.allreduce_gradients()
         timers('backward-params-all-reduce').stop()

     # All-reduce word_embeddings' grad across first and last stages to ensure

@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
     elif mpu.is_pipeline_last_stage(ignore_virtual=True):
         unwrapped_model = model[-1]
     unwrapped_model = unwrap_model(
-        unwrapped_model, (torchDDP, LocalDDP, FP16Module))
+        unwrapped_model, (torchDDP, LocalDDP, Float16Module))

     if unwrapped_model.share_word_embeddings:
         word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-        torch.distributed.all_reduce(word_embeddings_weight.grad,
-                                     group=mpu.get_embedding_group())
+        if args.DDP_impl == 'local':
+            grad = word_embeddings_weight.main_grad
+        else:
+            grad = word_embeddings_weight.grad
+        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()

     # Update parameters.
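The train_step() changes hinge on one detail: with contiguous buffers, gradients accumulate into a shared buffer across all micro-batches of an iteration, so the buffer is zeroed once per iteration via zero_grad_buffer() and reduced once as a flat tensor by allreduce_gradients(). A toy sketch of why the explicit per-iteration zeroing is needed (buffer and grad_views are illustrative stand-ins for MemoryBuffer.data and the per-parameter main_grad views):

import torch

buffer = torch.zeros(10)                             # MemoryBuffer.data stand-in
grad_views = [buffer[0:6].view(2, 3), buffer[6:10]]  # per-param main_grad views

for iteration in range(2):
    # zero_grad_buffer(): clearing the buffer clears every main_grad view at once.
    buffer.zero_()
    for micro_batch in range(4):                     # grads accumulate across micro-batches
        for g in grad_views:
            g.add_(1.0)
    # In the real code, a single flat all-reduce over `buffer` happens here.
    print('iteration', iteration, 'accumulated grad:', grad_views[0][0, 0].item())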
megatron/utils.py

@@ -48,12 +48,19 @@ def unwrap_model(model, module_instances=(torchDDP)):
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
+    args = get_args()
+    if not isinstance(model, list):
+        model = [model]
     # Remove duplicate params.
     params_data = []
-    for param in model.parameters():
-        is_not_shared = param_is_not_shared(param)
-        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        if is_not_shared and is_not_tp_duplicate:
-            params_data.append(param.data)
+    for model_ in model:
+        for param in model_.parameters():
+            is_not_shared = param_is_not_shared(param)
+            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            if is_not_shared and is_not_tp_duplicate:
+                if args.bf16:
+                    params_data.append(param.data.float())
+                else:
+                    params_data.append(param.data)
     # Calculate norm
     dummy_overflow_buf = torch.cuda.IntTensor([0])