OpenDAS / Megatron-LM / Commits

Commit aed2f75e, authored Apr 11, 2021 by Jared Casper
    Merge branch 'main' into github-main
Parents: 8aa4619f, f32a638d
Changes: 96
Showing 20 changed files with 1427 additions and 971 deletions (+1427 / -971)
Changed files:

    megatron/model/classification.py     +32   -72
    megatron/model/distributed.py        +178  -72
    megatron/model/enums.py              +28   -0
    megatron/model/fused_layer_norm.py   +28   -117
    megatron/model/fused_softmax.py      +86   -49
    megatron/model/gpt_model.py          +28   -92
    megatron/model/language_model.py     +153  -189
    megatron/model/module.py             +60   -28
    megatron/model/multiple_choice.py    +36   -76
    megatron/model/realm_model.py        +3    -2
    megatron/model/transformer.py        +231  -144
    megatron/model/utils.py              +5    -0
    megatron/model/vit_model.py          +210  -0
    megatron/mpu/__init__.py             +5    -1
    megatron/mpu/data.py                 +1    -1
    megatron/mpu/initialize.py           +46   -3
    megatron/mpu/layers.py               +7    -3
    megatron/optimizer/__init__.py       +63   -33
    megatron/optimizer/clip_grads.py     +32   -3
    megatron/optimizer/optimizer.py      +195  -86
megatron/model/classification.py  (view file @ aed2f75e)

@@ -19,7 +19,8 @@ import torch

 from megatron import get_args, print_rank_last
 from megatron import mpu
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.enums import AttnMaskType
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal

@@ -27,46 +28,57 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule


-class ClassificationBase(MegatronModule):
+class Classification(MegatronModule):

-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationBase, self).__init__(share_word_embeddings=False)
+    def __init__(self, num_classes, num_tokentypes=2,
+                 pre_process=True, post_process=True):
+        super(Classification, self).__init__(share_word_embeddings=False)
         args = get_args()

         self.num_classes = num_classes
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.classification_head = get_linear_layer(args.hidden_size,
                                                         self.num_classes,
                                                         init_method)
             self._classification_head_key = 'classification_head'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, model_input, attention_mask, tokentype_ids=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)
+        input_ids = model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             _, pooled_output = lm_output
             classification_output = self.classification_dropout(pooled_output)
             classification_logits = self.classification_head(classification_output)

@@ -86,7 +98,7 @@ class ClassificationBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._classification_head_key] \
                 = self.classification_head.state_dict(
                     destination, prefix, keep_vars)

@@ -97,7 +109,7 @@ class ClassificationBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._classification_head_key in state_dict:
                 self.classification_head.load_state_dict(
                     state_dict[self._classification_head_key], strict=strict)

@@ -105,55 +117,3 @@ class ClassificationBase(MegatronModule):
                 print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._classification_head_key))
-
-
-class Classification(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(Classification, self).forward(
-            input_ids, attention_mask, tokentype_ids=tokentype_ids)
-
-
-class ClassificationFirstStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationFirstStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(ClassificationFirstStage, self).forward(
-            input_ids, attention_mask, tokentype_ids=tokentype_ids)
-
-
-class ClassificationIntermediateStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationIntermediateStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class ClassificationLastStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationLastStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationLastStage, self).forward(
-            hidden_state, attention_mask)
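The hunk above replaces the per-pipeline-stage Classification classes with a single class driven by pre_process/post_process flags. A minimal sketch of how a task-side provider might now build it, assuming the usual Megatron pretrain/finetune plumbing supplies the flags and the task supplies num_classes (both are assumptions, not part of this diff):

    from megatron.model.classification import Classification

    def classification_model_provider(num_classes, pre_process=True, post_process=True):
        # One class covers first, intermediate, and last pipeline stages;
        # the flags decide whether the embedding and the classification
        # head exist on this rank.
        return Classification(num_classes=num_classes,
                              num_tokentypes=2,
                              pre_process=pre_process,
                              post_process=post_process)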
megatron/model/distributed.py  (view file @ aed2f75e)

@@ -13,100 +13,206 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from abc import ABC
+from abc import abstractmethod
+
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import torch.distributed as dist
-from torch.nn.modules import Module
-from torch.autograd import Variable

+from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule


-class DistributedDataParallel(MegatronModule):
-
-    def __init__(self, module):
-        super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
-        self.data_parallel_group = mpu.get_data_parallel_group()
-
-        def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                buckets = {}
-                for name, param in self.module.named_parameters():
-                    if param.requires_grad and param.grad is not None:
-                        tp = (param.data.type())
-                        if tp not in buckets:
-                            buckets[tp] = []
-                        buckets[tp].append(param)
-                if self.warn_on_half:
-                    if torch.cuda.HalfTensor in buckets:
-                        print("WARNING: gloo dist backend for half parameters may be extremely slow."
-                              + " It is recommended to use the NCCL backend in this case.")
-                        self.warn_on_half = False
-                for tp in buckets:
-                    bucket = buckets[tp]
-                    grads = [param.grad.data for param in bucket]
-                    coalesced = _flatten_dense_tensors(grads)
-                    if fp32_allreduce:
-                        coalesced = coalesced.float()
-                    if not no_scale and not reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    dist.all_reduce(coalesced, group=self.data_parallel_group)
-                    torch.cuda.synchronize()
-                    if not no_scale and reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                        buf.copy_(synced)
-
-        self.hook_handles = []
-        self.hooks = []
-        for param in list(self.module.parameters()):
-            def allreduce_hook(*unused):
-                Variable._execution_engine.queue_callback(allreduce_params)
-            # handle = param.register_hook(allreduce_hook)
-            # self.hooks.append(allreduce_hook)
-            # self.hook_handles.append(handle)
-        self.allreduce_params = allreduce_params

+class MemoryBuffer:
+
+    def __init__(self, numel, dtype):
+        self.numel = numel
+        self.dtype = dtype
+        self.data = torch.zeros(self.numel,
+                                dtype=self.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+
+    def zero(self):
+        """Reset the buffer to zero."""
+        self.data.zero_()
+
+    def get(self, shape, start_index):
+        """Return a tensor with the input `shape` as a view into the
+        1-D data starting at `start_index`."""
+        end_index = start_index + shape.numel()
+        assert end_index <= self.numel, \
+            'requested tensor is out of the buffer range.'
+        buffer_tensor = self.data[start_index:end_index]
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+
+class DistributedDataParallelBase(MegatronModule, ABC):
+    """Abstract class for DDP."""
+
+    def __init__(self, module):
+        super(DistributedDataParallelBase, self).__init__()
+        # Keep a pointer to the model.
+        self.module = module
+
+    @abstractmethod
+    def allreduce_gradients(self):
+        pass
+
     def forward(self, *inputs, **kwargs):
-        self.needs_reduction = True
         return self.module(*inputs, **kwargs)

     def state_dict(self, destination=None, prefix='', keep_vars=False):
-        #[h.remove() for h in self.hook_handles]
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-        # for handle, hook in zip(self.hook_handles, self.hooks):
-        #     d = handle.hooks_dict_ref()
-        #     d[handle.id] = hook
-        return sd
+        return self.module.state_dict(destination, prefix, keep_vars)

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                           keep_vars)

     def load_state_dict(self, state_dict, strict=True):
         self.module.load_state_dict(state_dict, strict=strict)

-    '''
-    def _sync_buffers(self):
-        buffers = list(self.module._all_buffers())
-        if len(buffers) > 0:
-            # cross-node buffer sync
-            flat_buffers = _flatten_dense_tensors(buffers)
-            dist.broadcast(flat_buffers, 0)
-            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
-                buf.copy_(synced)
-    def train(self, mode=True):
-        # Clear NCCL communicator and CUDA event cache of the default group ID,
-        # These cache will be recreated at the later call. This is currently a
-        # work-around for a potential NCCL deadlock.
-        if dist._backend == dist.dist_backend.NCCL:
-            dist._clear_group_cache()
-        super(DistributedDataParallel, self).train(mode)
-        self.module.train(mode)
-    '''

+class DistributedDataParallel(DistributedDataParallelBase):
+    """DDP with contiguous buffers options to store and accumulate gradients.
+    This class:
+        - has the potential to reduce memory fragmentation.
+        - provides the option to do the gradient accumulation
+          in a type other than the params type (for example fp32)
+    Arguments:
+        module: input model.
+        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
+            and the gradient all-reduce all in float32. If this option is
+            true, we require `use_contiguous_buffers` to be true too.
+        use_contiguous_buffers: if true, use a contiguous buffer to store the
+            gradients.
+    """
+
+    def __init__(self, module,
+                 accumulate_allreduce_grads_in_fp32,
+                 use_contiguous_buffers):
+
+        super(DistributedDataParallel, self).__init__(module)
+
+        self.accumulate_allreduce_grads_in_fp32 \
+            = accumulate_allreduce_grads_in_fp32
+        self.use_contiguous_buffers = use_contiguous_buffers
+        # If we are using fp32-accumulate-allreduce explicitly
+        # this means we need main grads in a continuous buffer.
+        if self.accumulate_allreduce_grads_in_fp32:
+            assert self.use_contiguous_buffers
+
+        # ===================================
+        # Rest of this part applies only to
+        # the case we use continuous buffers.
+        # ===================================
+        self._grad_buffers = None
+        if self.use_contiguous_buffers:
+            self._grad_buffers = {}
+
+            # Simple function to define buffer type.
+            def _get_buffer_type(param):
+                return torch.float if \
+                    self.accumulate_allreduce_grads_in_fp32 else param.dtype
+
+            # First calculate total number of elements per type.
+            type_num_elements = {}
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+                                               + param.data.nelement()
+
+            # Allocate the buffer.
+            for dtype, num_elements in type_num_elements.items():
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+            # Assume the back prop order is reverse the params order,
+            # store the start index for the gradients.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    dtype = _get_buffer_type(param)
+                    type_num_elements[dtype] -= param.data.nelement()
+                    param.main_grad = self._grad_buffers[dtype].get(
+                        param.data.shape, type_num_elements[dtype])
+
+            # Backward hook.
+            # Accumulation function for the gradients. We need
+            # to store them so they don't go out of scope.
+            self.grad_accs = []
+            # Loop over all the parameters in the model.
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    # Expand so we get access to grad_fn.
+                    param_tmp = param.expand_as(param)
+                    # Get the gradient accumulator function.
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    grad_acc.register_hook(self._make_param_hook(param))
+                    self.grad_accs.append(grad_acc)
+
+    def _make_param_hook(self, param):
+        """Create the all-reduce hook for backprop."""
+        # Hook used for back-prop.
+        def param_hook(*unused):
+            # Add the gradient to the buffer.
+            if param.grad.data is not None:
+                param.main_grad.add_(param.grad.data)
+                # Now we can deallocate grad memory.
+                param.grad = None
+        return param_hook
+
+    def zero_grad_buffer(self):
+        """Set the grad buffer data to zero. Needs to be called at the
+        beginning of each iteration."""
+        assert self._grad_buffers is not None, 'buffers are not initialized.'
+        for _, buffer_ in self._grad_buffers.items():
+            buffer_.zero()
+
+    def allreduce_gradients(self):
+        """Reduce gradients across data parallel ranks."""
+        # If we have buffers, simply reduce the data in the buffer.
+        if self._grad_buffers is not None:
+            for _, buffer_ in self._grad_buffers.items():
+                buffer_.data /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    buffer_.data, group=mpu.get_data_parallel_group())
+        else:
+            # Otherwise, bucketize and all-reduce
+            buckets = {}
+            # Pack the buckets.
+            for param in self.module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = param.data.type()
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+                    param.main_grad = param.grad
+
+            # For each bucket, all-reduce and copy all-reduced grads.
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                coalesced /= mpu.get_data_parallel_world_size()
+                torch.distributed.all_reduce(
+                    coalesced, group=mpu.get_data_parallel_group())
+                for buf, synced in zip(grads,
+                                       _unflatten_dense_tensors(coalesced, grads)):
+                    buf.copy_(synced)
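A small sketch of the MemoryBuffer contract introduced above: every parameter gets a `main_grad` view into one flat buffer, so zeroing or all-reducing the buffer touches all gradients at once. This assumes a CUDA device is available (the buffer is allocated on the current GPU) and that megatron is importable; the shapes are illustrative only.

    import torch
    from megatron.model.distributed import MemoryBuffer

    buf = MemoryBuffer(numel=6, dtype=torch.float)        # one flat fp32 buffer on GPU
    w_grad = buf.get(torch.Size([2, 3]), start_index=0)   # a view, not a copy
    w_grad.add_(1.0)                                      # writes go through to the buffer
    assert buf.data.sum().item() == 6.0
    buf.zero()                                            # zero_grad_buffer() does this per dtype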
megatron/model/enums.py  (new file, 0 → 100644, view file @ aed2f75e)

+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+
+
+class LayerType(enum.Enum):
+    encoder = 1
+    decoder = 2
+
+
+class AttnType(enum.Enum):
+    self_attn = 1
+    cross_attn = 2
+
+
+class AttnMaskType(enum.Enum):
+    padding = 1
+    causal = 2
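These enums replace the older boolean flags (e.g. upper-triangular vs. general mask fusion) used elsewhere in this commit. A trivial illustration of how callers such as FusedScaleMaskSoftmax and get_language_model select behavior by mask type, assuming megatron is importable:

    from megatron.model.enums import AttnMaskType, AttnType, LayerType

    mask_type = AttnMaskType.causal            # GPT-style left-to-right masking
    assert mask_type != AttnMaskType.padding   # BERT-style padding masking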
megatron/model/fused_layer_norm.py  (view file @ aed2f75e)

@@ -15,29 +15,23 @@
 """This code is copied from NVIDIA apex:
      https://github.com/NVIDIA/apex
-   with minor changes. """
+   with some changes. """

-import math
-import torch
 import numbers
+import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
-from torch.nn import functional as F
 import importlib

-global fused_layer_norm_cuda
-fused_layer_norm_cuda = None
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None


 class FusedLayerNormAffineFunction(torch.autograd.Function):

   @staticmethod
   def forward(ctx, input, weight, bias, normalized_shape, eps):
     global fused_mix_prec_layer_norm_cuda
     if fused_mix_prec_layer_norm_cuda is None:
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")
     ctx.normalized_shape = normalized_shape
     ctx.eps = eps
     input_ = input.contiguous()

@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
         input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
     ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
     return output

   @staticmethod
   def backward(ctx, grad_output):
     input_, weight_, bias_, mean, invvar = ctx.saved_tensors
     grad_input = grad_weight = grad_bias = None
-    grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
+    grad_input, grad_weight, grad_bias \
+      = fused_mix_prec_layer_norm_cuda.backward_affine(
         grad_output.contiguous(), mean, invvar,
         input_, ctx.normalized_shape,
         weight_, bias_, ctx.eps)
-    return grad_input, grad_weight, grad_bias, None, None
-
-
-class FusedLayerNormFunction(torch.autograd.Function):
-
-  @staticmethod
-  def forward(ctx, input, normalized_shape, eps):
-    global fused_layer_norm_cuda
-    if fused_layer_norm_cuda is None:
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-    ctx.normalized_shape = normalized_shape
-    ctx.eps = eps
-    input_ = input.contiguous()
-    output, mean, invvar = fused_layer_norm_cuda.forward(
-        input_, ctx.normalized_shape, ctx.eps)
-    ctx.save_for_backward(input_, mean, invvar)
-    return output
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input_, mean, invvar = ctx.saved_tensors
-    grad_input = None
-    grad_input = fused_layer_norm_cuda.backward(
-        grad_output.contiguous(), mean, invvar,
-        input_, ctx.normalized_shape, ctx.eps)
-    return grad_input, None, None
+    return grad_input, grad_weight, grad_bias, None, None

-
-def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction.apply(
-        input, weight, bias, normalized_shape, eps)
-
-
-def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction.apply(input, normalized_shape, eps)


 class MixedFusedLayerNorm(torch.nn.Module):

-  r"""Applies Layer Normalization over a mini-batch of inputs as described in
-  the paper `Layer Normalization`_ .
-  Currently only runs on cuda() tensors.
-  .. math::
-      y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-  The mean and standard-deviation are calculated separately over the last
-  certain number dimensions which have to be of the shape specified by
-  :attr:`normalized_shape`.
-  :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
-  :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
-  .. note::
-      Unlike Batch Normalization and Instance Normalization, which applies
-      scalar scale and bias for each entire channel/plane with the
-      :attr:`affine` option, Layer Normalization applies per-element scale and
-      bias with :attr:`elementwise_affine`.
-  This layer uses statistics computed from input data in both training and
-  evaluation modes.
-  Args:
-      normalized_shape (int or list or torch.Size): input shape from an expected input
-          of size
-          .. math::
-              [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
-                  \times \ldots \times \text{normalized\_shape}[-1]]
-          If a single integer is used, it is treated as a singleton list, and this module will
-          normalize over the last dimension which is expected to be of that specific size.
-      eps: a value added to the denominator for numerical stability. Default: 1e-5
-      elementwise_affine: a boolean value that when set to ``True``, this module
-          has learnable per-element affine parameters initialized to ones (for weights)
-          and zeros (for biases). Default: ``True``.
-  Shape:
-      - Input: :math:`(N, *)`
-      - Output: :math:`(N, *)` (same shape as input)
-  Examples::
-      >>> input = torch.randn(20, 5, 10, 10)
-      >>> # With Learnable Parameters
-      >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
-      >>> # Without Learnable Parameters
-      >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
-      >>> # Normalize over last two dimensions
-      >>> m = apex.normalization.FusedLayerNorm([10, 10])
-      >>> # Normalize over last dimension of size 10
-      >>> m = apex.normalization.FusedLayerNorm(10)
-      >>> # Activating the module
-      >>> output = m(input)
-  .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
-  """
-
-  def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+  def __init__(self, normalized_shape, eps=1e-5):
     super(MixedFusedLayerNorm, self).__init__()

-    global fused_layer_norm_cuda
-    fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
     global fused_mix_prec_layer_norm_cuda
-    fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
+    fused_mix_prec_layer_norm_cuda = importlib.import_module(
+        "fused_mix_prec_layer_norm_cuda")

     if isinstance(normalized_shape, numbers.Integral):
         normalized_shape = (normalized_shape,)
     self.normalized_shape = torch.Size(normalized_shape)
     self.eps = eps
-    self.elementwise_affine = elementwise_affine
-    if self.elementwise_affine:
-        self.weight = Parameter(torch.Tensor(*normalized_shape))
-        self.bias = Parameter(torch.Tensor(*normalized_shape))
-    else:
-        self.register_parameter('weight', None)
-        self.register_parameter('bias', None)
+    self.weight = Parameter(torch.Tensor(*normalized_shape))
+    self.bias = Parameter(torch.Tensor(*normalized_shape))
     self.reset_parameters()

-  def reset_parameters(self):
-    if self.elementwise_affine:
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
-
-  def forward(self, input):
-    if not input.is_cuda:
-        return F.layer_norm(
-            input, self.normalized_shape, self.weight, self.bias, self.eps)
-    if self.elementwise_affine:
-        return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape, self.eps)
-    else:
-        return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
-
-  def extra_repr(self):
-    return '{normalized_shape}, eps={eps}, ' \
-        'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+  def reset_parameters(self):
+    init.ones_(self.weight)
+    init.zeros_(self.bias)
+
+  def forward(self, input):
+    return FusedLayerNormAffineFunction.apply(
+        input, self.weight, self.bias, self.normalized_shape, self.eps)
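A hedged usage sketch for the slimmed-down MixedFusedLayerNorm. It assumes the apex-derived `fused_mix_prec_layer_norm_cuda` extension is built and a GPU is present; the fp16 input with fp32 weight/bias is the mixed-precision case this kernel is meant to cover.

    import torch
    from megatron.model.fused_layer_norm import MixedFusedLayerNorm

    ln = MixedFusedLayerNorm(1024, eps=1e-5).cuda()
    x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.half)
    y = ln(x)                      # weight/bias stay fp32, activations stay fp16
    assert y.shape == x.shape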
megatron/model/fused_softmax.py  (view file @ aed2f75e)

@@ -14,114 +14,151 @@
 # limitations under the License.

 import torch
+from megatron.model.enums import AttnMaskType


 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply upper triangular mask (typically used in gpt models).
     3. Perform softmax.
     """

     @staticmethod
     def forward(ctx, inputs, scale):
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
-        softmax_results = \
-            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
+        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
+            inputs, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = \
-            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
-                                                             softmax_results,
-                                                             scale_t[0])
+        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
         return input_grads, None


 class ScaledMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply the mask.
     3. Perform softmax.
     """

     @staticmethod
     def forward(ctx, inputs, mask, scale):
         import scaled_masked_softmax_cuda

         scale_t = torch.tensor([scale])
-        softmax_results = \
-            scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+        softmax_results = scaled_masked_softmax_cuda.forward(
+            inputs, mask, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads):
         import scaled_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = \
-            scaled_masked_softmax_cuda.backward(output_grads,
-                                                softmax_results,
-                                                scale_t[0])
+        input_grads = scaled_masked_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
         return input_grads, None, None


 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(
+        self,
+        input_in_fp16,
+        input_in_bf16,
+        attn_mask_type,
+        scaled_masked_softmax_fusion,
+        mask_func,
+        softmax_in_fp32,
+        scale,
+    ):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.input_in_bf16 = input_in_bf16
+        assert not (self.input_in_fp16 and self.input_in_bf16), \
+            'both fp16 and bf16 flags cannot be active at the same time.'
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale

-        assert self.scale is None or softmax_in_fp32, \
-            'softmax should be in fp32 when scaled'
+        assert (
+            self.scale is None or softmax_in_fp32
+        ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
         assert input.dim() == 4
         data_size = input.size()
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
+        attn_batch_size = data_size[0] * data_size[1]
+
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

         # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-           (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-           input.size()[2] == input.size()[3]:
+        if self.input_in_float16 and mask is not None and \
+           custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                 probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                 probs = probs.view(*data_size)
             else:
+                assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
-            if self.input_in_fp16 and self.softmax_in_fp32:
+            if self.input_in_float16 and self.softmax_in_fp32:
                 input = input.float()

             if self.scale is not None:
                 input = input * self.scale
-            mask_output = self.mask_func(input, mask)
+            mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)

-            if self.input_in_fp16 and self.softmax_in_fp32:
-                probs = probs.half()
+            if self.input_in_float16 and self.softmax_in_fp32:
+                if self.input_in_fp16:
+                    probs = probs.half()
+                else:
+                    probs = probs.bfloat16()

         return probs
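A hedged sketch exercising the unfused fallback branch of FusedScaleMaskSoftmax shown above (the torch.nn.Softmax path, so no custom CUDA extensions are needed). The mask function mirrors the masked_fill_ style used by the GPT model in this commit; shapes are illustrative.

    import torch
    from megatron.model.enums import AttnMaskType
    from megatron.model.fused_softmax import FusedScaleMaskSoftmax

    def attention_mask_func(scores, mask):
        return scores.masked_fill_(mask, -10000.0)

    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False, input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,   # force the fallback branch
        mask_func=attention_mask_func,
        softmax_in_fp32=True, scale=None)

    scores = torch.randn(2, 4, 16, 16)                    # [b, np, sq, sk]
    mask = torch.zeros(2, 1, 16, 16, dtype=torch.bool)    # nothing masked
    probs = softmax(scores, mask)                         # softmax over the key dimension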
megatron/model/gpt_model.py  (view file @ aed2f75e)

@@ -21,17 +21,13 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal


-def gpt_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(ltor_mask, -10000.0)
-    return attention_scores
-
-
 def post_language_model_processing(lm_output, labels, logit_weights,
                                    get_key_value, parallel_output,
                                    forward_method_parallel_output,

@@ -61,40 +57,50 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         return loss


-class GPTModelBase(MegatronModule):
+class GPTModel(MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(GPTModel, self).__init__()
         args = get_args()

         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=gpt_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt_model_input, attention_mask, labels=None,
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

-        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt_model_input
-            args = [input_ids, position_ids, attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [gpt_model_input, attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
+        lm_output = self.language_model(input_ids,
+                                        position_ids,
+                                        attention_mask,
+                                        layer_past=layer_past,
+                                        get_key_value=get_key_value)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),

@@ -113,7 +119,7 @@ class GPTModelBase(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

@@ -122,79 +128,9 @@ class GPTModelBase(MegatronModule):
         """Customized load."""

         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)
-
-
-class GPTModel(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            labels=labels,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
-
-
-class GPTModelFirstStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPTModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class GPTModelIntermediateStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(GPTModelIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class GPTModelLastStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask, labels=None,
-                layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModelLastStage, self).forward(
-            hidden_state,
-            attention_mask,
-            labels=labels,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
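As with Classification, the per-stage GPT classes are gone; a stage is now described purely by two flags. A minimal sketch of a model provider using the refactored interface, assuming Megatron's global args and process groups have already been initialized by the usual pretrain() entry point (that plumbing is not part of this hunk):

    from megatron.model.gpt_model import GPTModel

    def gpt_model_provider(pre_process=True, post_process=True):
        return GPTModel(num_tokentypes=0,
                        parallel_output=True,
                        pre_process=pre_process,    # this rank owns the embedding
                        post_process=post_process)  # this rank owns the output head / loss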
megatron/model/language_model.py  (view file @ aed2f75e)

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal

@@ -42,8 +43,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return mpu.gather_from_tensor_model_parallel_region(logits_parallel)


-def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+def get_language_model(num_tokentypes, add_pooler,
+                       encoder_attn_mask_type, init_method=None,
+                       scaled_init_method=None, add_decoder=False,
+                       decoder_attn_mask_type=AttnMaskType.causal,
+                       pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -51,27 +55,21 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method = init_method_normal(args.init_method_std)

     if scaled_init_method is None:
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)

     # Language model.
-    args = [attention_mask_func, init_method, scaled_init_method]
-    kwargs = {}
-    cls = None
-    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModel
-        kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['add_pooler'] = add_pooler
-    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelFirstStage
-        kwargs['num_tokentypes'] = num_tokentypes
-    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelLastStage
-        kwargs['add_pooler'] = add_pooler
-    else:
-        cls = TransformerLanguageModelIntermediateStage
-    # Language model.
-    language_model = cls(*args, **kwargs)
+    language_model = TransformerLanguageModel(
+        init_method,
+        scaled_init_method,
+        encoder_attn_mask_type,
+        num_tokentypes=num_tokentypes,
+        add_decoder=add_decoder,
+        decoder_attn_mask_type=decoder_attn_mask_type,
+        add_pooler=add_pooler,
+        pre_process=pre_process,
+        post_process=post_process
+    )
     # key used for checkpoints.
     language_model_key = 'language_model'

@@ -257,17 +255,11 @@ class Embedding(MegatronModule):
                   'checkpoint but could not find it', flush=True)


-class TransformerLanguageModelBase(MegatronModule):
+class TransformerLanguageModel(MegatronModule):
     """Transformer language model.
     Arguments:
         transformer_hparams: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-               masked-attention-scores = attention_mask_func(
-                                     unmaksed-attention-scores, attention-mask)
         vocab_size: vocabulary size
         max_sequence_length: maximum size of sequence. This
                              is used for positional embedding

@@ -277,21 +269,30 @@ class TransformerLanguageModelBase(MegatronModule):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 add_pooler=False):
-        super(TransformerLanguageModelBase, self).__init__()
+                 add_decoder=False,
+                 decoder_attn_mask_type=AttnMaskType.causal,
+                 add_pooler=False,
+                 pre_process=True,
+                 post_process=True):
+        super(TransformerLanguageModel, self).__init__()
         args = get_args()

+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.encoder_attn_mask_type = encoder_attn_mask_type
+        self.add_decoder = add_decoder
+        self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler

         # Embeddings.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,

@@ -301,57 +302,109 @@ class TransformerLanguageModelBase(MegatronModule):
             self._embedding_key = 'embedding'

         # Transformer.
-        self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method, output_layer_init_method)
-        self._transformer_key = 'transformer'
+        self.encoder = ParallelTransformer(
+            self.init_method,
+            output_layer_init_method,
+            self_attn_mask_type=self.encoder_attn_mask_type,
+            pre_process=self.pre_process,
+            post_process=self.post_process
+        )
+        self._encoder_key = 'encoder'

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            self.pooler = Pooler(self.hidden_size, self.init_method)
-            self._pooler_key = 'pooler'
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
+            self.decoder = ParallelTransformer(
+                self.init_method,
+                output_layer_init_method,
+                layer_type=LayerType.decoder,
+                self_attn_mask_type=self.decoder_attn_mask_type)
+            self._decoder_key = 'decoder'
+
+        if self.post_process:
+            # Pooler.
+            if self.add_pooler:
+                self.pooler = Pooler(self.hidden_size, self.init_method)
+                self._pooler_key = 'pooler'
+
+    def set_input_tensor(self, input_tensor):
+        """ See megatron.model.transformer.set_input_tensor()"""
+        self.encoder.set_input_tensor(input_tensor)

-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
-            embedding_output = self.embedding(input_ids, position_ids,
+        if self.pre_process:
+            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                               tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
         else:
-            transformer_input = language_model_input
+            encoder_input = None

-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
-
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
-
-        return transformer_output
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)
+
+        if self.post_process:
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)
+
+        # output_enc_hidden refers to when we just need the encoder's
+        # output. For example, it is helpful to compute
+        # similarity between two sequences by average pooling
+        if not self.add_decoder or output_enc_hidden:
+            if self.add_pooler and self.post_process:
+                return encoder_output, pooled_output
+            else:
+                return encoder_output
+
+        # Decoder Embedding
+        dec_embedding_output = self.embedding(dec_input_ids,
+                                              dec_position_ids)
+        # decoder
+        decoder_output = self.decoder(dec_embedding_output,
+                                      dec_attn_mask,
+                                      layer_past=layer_past,
+                                      get_key_value=get_key_value,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask)
+
+        if self.add_pooler and self.post_process:
+            return decoder_output, encoder_output, pooled_output
+        else:
+            return decoder_output, encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""

         state_dict_ = {}
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        if self.post_process:
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+        if self.add_decoder:
+            state_dict_[self._decoder_key] \
+                = self.decoder.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)

         return state_dict_

@@ -360,7 +413,7 @@ class TransformerLanguageModelBase(MegatronModule):
         """Customized load."""

         # Embedding.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             if self._embedding_key in state_dict:
                 state_dict_ = state_dict[self._embedding_key]
             else:

@@ -371,130 +424,41 @@ class TransformerLanguageModelBase(MegatronModule):
                     state_dict_[key] = state_dict[key]
             self.embedding.load_state_dict(state_dict_, strict=strict)

-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
         else:
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
                 if 'transformer.' in key:
                     state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
-
-
-class TransformerLanguageModel(TransformerLanguageModelBase):
-    """Transformer language model (see TransformerLanguageModelBase
-    for description of arguments).
-    """
-
-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, num_tokentypes=0,
-                 add_pooler=False):
-        super(TransformerLanguageModel, self).__init__(
-            attention_mask_func, init_method, output_layer_init_method,
-            num_tokentypes=num_tokentypes, add_pooler=add_pooler)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
-        return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index)
-
-
-class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
-    """Transformer language model, first stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, num_tokentypes=0):
-        super(TransformerLanguageModelFirstStage, self).__init__(
-            attention_mask_func, init_method, output_layer_init_method,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
-    """Transformer language model, intermediate stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method):
-        super(TransformerLanguageModelIntermediateStage, self).__init__(
-            attention_mask_func, init_method, output_layer_init_method)
-
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelIntermediateStage, self).forward(
-            hidden_states,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
-    """Transformer language model, final stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, add_pooler=False):
-        super(TransformerLanguageModelLastStage, self).__init__(
-            attention_mask_func, init_method, output_layer_init_method,
-            add_pooler=add_pooler)
-
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
-        return super(TransformerLanguageModelLastStage, self).forward(
-            hidden_states,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index)
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[key.replace(".attention.",
+                    ".self_attention.")] = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if self.post_process:
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+        # decoder
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for pooler in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)
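A hedged sketch of the new get_language_model() signature as it is used by the GPT- and BERT-style callers in this commit. It assumes Megatron's global args have been initialized (init methods default to the args-derived normal initializers when left as None); values shown are illustrative.

    from megatron.model.enums import AttnMaskType
    from megatron.model.language_model import get_language_model

    language_model, lm_key = get_language_model(
        num_tokentypes=0,
        add_pooler=False,
        encoder_attn_mask_type=AttnMaskType.causal,
        init_method=None,          # falls back to init_method_normal(init_method_std)
        scaled_init_method=None,   # falls back to scaled_init_method_normal(...)
        pre_process=True,
        post_process=True)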
megatron/model/module.py
View file @
aed2f75e
...
...
@@ -25,6 +25,13 @@ from megatron import mpu
_FLOAT_TYPES
=
(
torch
.
FloatTensor
,
torch
.
cuda
.
FloatTensor
)
_HALF_TYPES
=
(
torch
.
HalfTensor
,
torch
.
cuda
.
HalfTensor
)
_BF16_TYPES
=
(
torch
.
BFloat16Tensor
,
torch
.
cuda
.
BFloat16Tensor
)
def
param_is_not_shared
(
param
):
return
not
hasattr
(
param
,
'shared'
)
or
not
param
.
shared
class
MegatronModule
(
torch
.
nn
.
Module
):
...
...
@@ -44,9 +51,9 @@ class MegatronModule(torch.nn.Module):
def
word_embeddings_weight
(
self
):
if
mpu
.
is_pipeline_first_stage
():
if
mpu
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
return
self
.
language_model
.
embedding
.
word_embeddings
.
weight
if
mpu
.
is_pipeline_last_stage
():
if
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
not
self
.
share_word_embeddings
:
raise
Exception
(
'word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false'
)
...
...
@@ -60,6 +67,13 @@ class MegatronModule(torch.nn.Module):
if
not
self
.
share_word_embeddings
:
raise
Exception
(
'initialize_word_embeddings() was called but '
'share_word_embeddings is false'
)
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if
args
.
pipeline_model_parallel_size
==
1
:
return
# Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
...
...
@@ -73,22 +87,28 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if
mpu
.
is_pipeline_last_stage
():
if
not
mpu
.
is_pipeline_first_stage
()
:
self
.
_word_embeddings_for_head_key
=
'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings
# weights to 0 here, then copy first
stage's weights using
# all_reduce below.
self
.
word_embeddings
=
mpu
.
VocabParallelEmbedding
(
args
.
padded_vocab_size
,
args
.
hidden_size
,
init_method
=
init_method_normal
(
args
.
init_method_std
)
)
self
.
word_embeddings
.
weight
.
data
.
fill_
(
0
)
self
.
word_embeddings
.
weight
.
shared
=
True
assert
not
mpu
.
is_pipeline_first_stage
()
self
.
_word_embeddings_for_head_key
=
'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
#
stage's weights using
all_reduce below.
self
.
word_embeddings
=
mpu
.
VocabParallelEmbedding
(
args
.
padded_vocab_size
,
args
.
hidden_size
,
init_method
=
init_method_normal
(
args
.
init_method_std
))
self
.
word_embeddings
.
weight
.
data
.
fill_
(
0
)
self
.
word_embeddings
.
weight
.
shared
=
True
# Ensure that first and last stages have the same initial parameter
# values.
if
mpu
.
is_pipeline_first_stage
()
or
mpu
.
is_pipeline_last_stage
():
torch
.
distributed
.
all_reduce
(
self
.
word_embeddings_weight
().
data
,
group
=
mpu
.
get_embedding_group
())
if
torch
.
distributed
.
is_initialized
():
if
mpu
.
is_pipeline_first_stage
()
or
mpu
.
is_pipeline_last_stage
():
torch
.
distributed
.
all_reduce
(
self
.
word_embeddings_weight
().
data
,
group
=
mpu
.
get_embedding_group
())
else
:
print
(
"WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
def
conversion_helper
(
val
,
conversion
):
...
...
@@ -102,44 +122,56 @@ def conversion_helper(val, conversion):
    return rtn


def fp32_to_fp16(val):
    """Convert fp32 `val` to fp16"""
def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16"""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = val.half()
            val = float16_convertor(val)
        return val
    return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
    """Convert fp16 `val` to fp32"""
def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32"""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _HALF_TYPES):
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)


class FP16Module(MegatronModule):
class Float16Module(MegatronModule):

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module('module', module.half())
            def float16_convertor(val):
                return val.half()
        elif args.bf16:
            self.add_module('module', module.bfloat16())
            def float16_convertor(val):
                return val.bfloat16()
        else:
            raise Exception('should not be here')

    def __init__(self, module):
        super(FP16Module, self).__init__()
        self.add_module('module', module.half())

        self.float16_convertor = float16_convertor

    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_fp16(inputs)
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = fp16_to_fp32(outputs)
            outputs = float16_to_fp32(outputs)
        return outputs
...
...
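A minimal sketch (not part of this commit) of the recursive conversion idea used by fp32_to_float16/float16_to_fp32 above: walk nested tuples and lists, apply a dtype convertor to each floating-point tensor leaf, and leave everything else untouched. The function and variable names here are illustrative only.

import torch

def convert_nested(val, convertor):
    # Recurse into tuples/lists, convert floating-point tensor leaves.
    if isinstance(val, (tuple, list)):
        return type(val)(convert_nested(v, convertor) for v in val)
    if torch.is_tensor(val) and val.is_floating_point():
        return convertor(val)
    return val

# Example: cast a mixed structure to bfloat16 and back to fp32.
batch = (torch.randn(2, 4), [torch.randn(3), torch.arange(5)])
bf16_batch = convert_nested(batch, lambda t: t.bfloat16())
fp32_batch = convert_nested(bf16_batch, lambda t: t.float())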
megatron/model/multiple_choice.py
View file @ aed2f75e
...
...
@@ -19,7 +19,8 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
...
...
@@ -27,29 +28,40 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


class MultipleChoiceBase(MegatronModule):
class MultipleChoice(MegatronModule):

    def __init__(self, num_tokentypes=2):
        super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
    def __init__(self,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        super(MultipleChoice, self).__init__(share_word_embeddings=False)
        args = get_args()

        init_method = init_method_normal(args.init_method_std)
        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head.
        if mpu.is_pipeline_last_stage():
        if self.post_process:
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     init_method)
            self._multichoice_head_key = 'multichoice_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
...
...
@@ -63,22 +75,21 @@ class MultipleChoiceBase(MegatronModule):
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = model_input
            # Do the same as attention_mask for input_ids, tokentype_ids
            assert len(input_ids.shape) == 3
            assert len(tokentype_ids.shape) == 3
            input_ids = input_ids.view(-1, input_ids.size(-1))
            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
            position_ids = bert_position_ids(input_ids)
            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
        input_ids = model_input
        # Do the same as attention_mask for input_ids, tokentype_ids
        assert len(input_ids.shape) == 3
        assert len(tokentype_ids.shape) == 3
        input_ids = input_ids.view(-1, input_ids.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )
        if self.post_process:
            _, pooled_output = lm_output
            multichoice_output = self.multichoice_dropout(pooled_output)
            multichoice_logits = self.multichoice_head(multichoice_output)
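A minimal sketch (not part of this commit) of the [batch, choices, sequence] --> [batch * choices, sequence] folding done in the forward pass above: every choice is scored independently and the per-choice logits are reshaped back at the end. Shapes and the stand-in logits are illustrative only.

import torch

batch, choices, seq_len = 4, 2, 16
input_ids = torch.randint(0, 1000, (batch, choices, seq_len))
attention_mask = torch.ones(batch, choices, seq_len, dtype=torch.long)

# Fold the choices dimension into the batch dimension.
flat_ids = input_ids.view(-1, input_ids.size(-1))             # [8, 16]
flat_mask = attention_mask.view(-1, attention_mask.size(-1))  # [8, 16]

# After the language model, one logit per (example, choice) is reshaped back.
flat_logits = torch.randn(batch * choices, 1)  # stand-in for the head output
logits = flat_logits.view(-1, choices)         # [4, 2]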
...
...
@@ -98,7 +109,7 @@ class MultipleChoiceBase(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
        if self.post_process:
            state_dict_[self._multichoice_head_key] \
                = self.multichoice_head.state_dict(
                    destination, prefix, keep_vars)
...
...
@@ -109,7 +120,7 @@ class MultipleChoiceBase(MegatronModule):
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
        if self.post_process:
            if self._multichoice_head_key in state_dict:
                self.multichoice_head.load_state_dict(
                    state_dict[self._multichoice_head_key], strict=strict)
...
...
@@ -117,54 +128,3 @@ class MultipleChoiceBase(MegatronModule):
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._multichoice_head_key))


class MultipleChoice(MultipleChoiceBase):

    def __init__(self, num_tokentypes=2):
        super(MultipleChoice, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        return super(MultipleChoice, self).forward(
            input_ids, attention_mask, tokentype_ids=tokentype_ids)


class MultipleChoiceFirstStage(MultipleChoiceBase):

    def __init__(self, num_tokentypes=2):
        super(MultipleChoiceFirstStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        return super(MultipleChoiceFirstStage, self).forward(
            input_ids, attention_mask, tokentype_ids=tokentype_ids)


class MultipleChoiceIntermediateStage(MultipleChoiceBase):

    def __init__(self, num_tokentypes=2):
        super(MultipleChoiceIntermediateStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        return super(MultipleChoiceIntermediateStage, self).forward(
            hidden_state, attention_mask)


class MultipleChoiceLastStage(MultipleChoiceBase):

    def __init__(self, num_tokentypes=2):
        super(MultipleChoiceLastStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        return super(MultipleChoiceLastStage, self).forward(
            hidden_state, attention_mask)
megatron/model/realm_model.py
View file @ aed2f75e
...
...
@@ -6,11 +6,12 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids


def general_ict_model_provider(only_query_model=False, only_block_model=False):
...
...
@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule):
                                                      args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)
...
...
megatron/model/transformer.py
View file @ aed2f75e
...
...
@@ -14,7 +14,6 @@
# limitations under the License.

"""Transformer."""
import math

import torch
import torch.nn.functional as F
...
...
@@ -22,11 +21,11 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
...
...
@@ -47,12 +46,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                unmasked-attention-scores, attention-mask)
"""

class ParallelMLP(MegatronModule):
...
...
@@ -71,7 +64,7 @@ class ParallelMLP(MegatronModule):
        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)
...
...
@@ -85,12 +78,12 @@ class ParallelMLP(MegatronModule):
        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
...
...
@@ -109,41 +102,61 @@ class ParallelMLP(MegatronModule):
        return output, output_bias


class ParallelSelfAttention(MegatronModule):
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)
            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
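A minimal sketch (not part of this commit) of the per-partition arithmetic used in ParallelAttention.__init__ above: the projection size is kv_channels times the number of heads, and both the projection and the heads are split evenly across tensor-parallel ranks. Values below are illustrative only.

def divide(numerator, denominator):
    assert numerator % denominator == 0, \
        f'{numerator} is not divisible by {denominator}'
    return numerator // denominator

kv_channels = 128                 # hidden size of each attention head
num_attention_heads = 32
tensor_parallel_world_size = 4

projection_size = kv_channels * num_attention_heads                      # 4096
hidden_size_per_partition = divide(projection_size,
                                   tensor_parallel_world_size)            # 1024
hidden_size_per_attention_head = divide(projection_size,
                                        num_attention_heads)              # 128
num_attention_heads_per_partition = divide(num_attention_heads,
                                           tensor_parallel_world_size)    # 8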
...
...
@@ -152,10 +165,10 @@ class ParallelSelfAttention(MegatronModule):
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            self.attention_mask_func,
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
...
...
@@ -166,72 +179,55 @@ class ParallelSelfAttention(MegatronModule):
        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size();
        if num_splits_first:
            """[s, b, num_splits * np * hn]
            -->(view) [s, b, num_splits, np, hn]
            -->(tranpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

        # =====================
        # Query, Key, and Value
        # =====================

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits]
            -->(view) [s, b, np, hn, num_splits]
            -->(tranpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            intermediate_shape = input_shape[:-1] +\
            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)
        return mixed_layer

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [sq, b, h]

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

        # =====================
        # Query, Key, and Value
        # =====================

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
...
...
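A minimal sketch (not part of this commit) of the layout change performed by _transpose_last_dim above for checkpoint_version == 0: the last dimension is reinterpreted as [num_splits, np, hn], the num_splits and np axes are swapped, and the result is flattened back, giving the [np, num_splits, hn] ordering expected by the new code. Sizes are illustrative only.

import torch

s, b = 2, 3                       # sequence length, batch
num_splits, np_, hn = 3, 4, 8     # QKV splits, heads per partition, head size

x = torch.randn(s, b, num_splits * np_ * hn)   # [s, b, 3 * np * hn]
y = x.view(s, b, num_splits, np_, hn)          # split the last dimension
y = y.transpose(-2, -3).contiguous()           # [s, b, np, 3, hn]
y = y.view(s, b, np_ * num_splits * hn)        # back to [s, b, np * 3 * hn]
assert y.shape == x.shape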
@@ -246,41 +242,41 @@ class ParallelSelfAttention(MegatronModule):
        if get_key_value:
            present = (key_layer, value_layer)


        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)


        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================
...
...
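A minimal sketch (not part of this commit) of the batched score computation above: batch and heads are folded into one dimension and torch.baddbmm is used so the 1/sqrt(hn) scaling is fused into the matmul (with beta=0 the preallocated buffer is ignored). Shapes are illustrative only.

import math
import torch

sq, sk, b, np_, hn = 5, 7, 2, 4, 8
query = torch.randn(sq, b, np_, hn)
key = torch.randn(sk, b, np_, hn)

q = query.view(sq, b * np_, hn).transpose(0, 1)                 # [b*np, sq, hn]
k = key.view(sk, b * np_, hn).transpose(0, 1).transpose(1, 2)   # [b*np, hn, sk]

scores = torch.baddbmm(
    torch.empty(b * np_, sq, sk), q, k,
    beta=0.0, alpha=1.0 / math.sqrt(hn))                         # [b*np, sq, sk]
scores = scores.view(b, np_, sq, sk)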
@@ -298,7 +294,6 @@ class ParallelSelfAttention(MegatronModule):
                    :attention_scores.size(3),
                    :attention_scores.size(3)]


        # ===========================
        # Attention probs and dropout
        # ===========================
...
...
@@ -312,7 +307,6 @@ class ParallelSelfAttention(MegatronModule):
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)


        # =========================
        # Context layer. [sq, b, hp]
        # =========================
...
...
@@ -321,21 +315,21 @@ class ParallelSelfAttention(MegatronModule):
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
...
...
@@ -348,7 +342,6 @@ class ParallelSelfAttention(MegatronModule):
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)


        # =================
        # Output. [sq, b, h]
        # =================
...
...
@@ -361,7 +354,7 @@ class ParallelSelfAttention(MegatronModule):
        return output, bias


def bias_dropout_add(x, bias, residual, prob, training) :
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
...
...
@@ -375,13 +368,13 @@ def get_bias_dropout_add(training):


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
...
...
@@ -389,66 +382,85 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformore layer takes input with size [b, s, h] and returns an
    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
                                               output_layer_init_method,
                                               layer_number)
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the input data.
        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [b, s, h]

        # Layer norm at the begining of the transformer layer.
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value)
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
...
...
@@ -459,7 +471,7 @@ class ParallelTransformerLayer(MegatronModule):
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        #re-enable torch grad to enable fused optimization.
        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
...
...
@@ -470,16 +482,38 @@ class ParallelTransformerLayer(MegatronModule):
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        #re-enable torch grad to enable fused optimization.
        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
...
...
@@ -496,12 +530,18 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, attention_mask_func,
                 init_method, output_layer_init_method):
    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
...
...
@@ -515,15 +555,38 @@ class ParallelTransformer(MegatronModule):
        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if mpu.is_pipeline_last_stage():
        if self.post_process:
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)
...
...
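A minimal sketch (not part of this commit) of the layer-offset arithmetic described in the comments above: with interleaved (virtual) pipeline stages, each model chunk on a pipeline rank owns a short, strided slice of the global layers. The helper below reproduces the two example assignments from the comments; names are illustrative only.

def layer_offset(pipeline_rank, virtual_rank, num_layers,
                 pipeline_size, virtual_size):
    layers_per_chunk = num_layers // (pipeline_size * virtual_size)
    return (virtual_rank * (num_layers // virtual_size)
            + pipeline_rank * layers_per_chunk)

# 8 layers, 2 pipeline stages, 2 virtual stages:
for pipeline_rank in range(2):
    chunks = [list(range(layer_offset(pipeline_rank, v, 8, 2, 2),
                         layer_offset(pipeline_rank, v, 8, 2, 2) + 2))
              for v in range(2)]
    print(f"Stage {pipeline_rank}: {chunks}")
# Stage 0: [[0, 1], [4, 5]]
# Stage 1: [[2, 3], [6, 7]]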
@@ -531,14 +594,18 @@ class ParallelTransformer(MegatronModule):
    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask):
    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward
...
...
@@ -548,13 +615,23 @@ class ParallelTransformer(MegatronModule):
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
...
...
@@ -566,7 +643,7 @@ class ParallelTransformer(MegatronModule):
            'get_key_value does not work with ' \
            'activation checkpointing'

        if mpu.is_pipeline_first_stage():
        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
...
...
@@ -574,10 +651,18 @@ class ParallelTransformer(MegatronModule):
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
...
...
@@ -588,14 +673,16 @@ class ParallelTransformer(MegatronModule):
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
...
...
megatron/model/utils.py
View file @ aed2f75e
...
...
@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
    return init_


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
...
...
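A minimal sketch (not part of this commit) of what attention_mask_func above does: positions where the mask is True are filled with a large negative value, so they receive approximately zero weight after the softmax. Shapes are illustrative only.

import torch

scores = torch.zeros(1, 1, 2, 4)                       # [b, np, sq, sk]
mask = torch.tensor([[[[False, False, True, True]]]])  # True = masked out
masked = scores.masked_fill(mask, -10000.0)
probs = torch.softmax(masked, dim=-1)
# probs[..., :2] ~= 0.5 each, probs[..., 2:] ~= 0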
megatron/model/vit_model.py
0 → 100644
View file @ aed2f75e
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from .module import MegatronModule


class VitMlpHead(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        x = hidden_states[:, sequence_index, :]
        x = self.dense_in(x)
        x = torch.tanh(x)
        x = self.dense_out(x)
        return x


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):

    args = get_args()
    num_patches_per_dim = args.img_dim // args.patch_dim
    num_patches = num_patches_per_dim ** 2
    seq_length = num_patches + 1
    hidden_size = args.hidden_size

    key = prefix + "weight"
    # import pdb
    # pdb.set_trace()
    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        assert input_param.shape[1] == hidden_size
        if input_param.shape[0] != seq_length:
            # update input_param and load it to state_dict[key]
            num_tok_input = input_param.shape[0] - 1
            num_tok_new = seq_length - 1
            input_param_tok, input_param_grid = (
                input_param[:1, :],
                input_param[1:, :],
            )

            gs_input = int(math.sqrt(num_tok_input))
            gs_new = int(math.sqrt(num_tok_new))

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = gs_new / gs_input

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size
            input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
            assert (
                input_param.shape[0] == seq_length
                and input_param.shape[1] == hidden_size
            )

            state_dict[key] = input_param
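A minimal sketch (not part of this commit) of the 2D interpolation idea in the hook above: keep the class-token row, reshape the remaining rows into a (hidden, grid, grid) image, resize it bilinearly to the new grid, and flatten it back. Sizes are illustrative only.

import torch
import torch.nn.functional as F

hidden, gs_in, gs_new = 16, 4, 8
pos_emb = torch.randn(1 + gs_in * gs_in, hidden)      # [CLS] + 4x4 grid

cls_row, grid = pos_emb[:1], pos_emb[1:]
grid = grid.transpose(0, 1).reshape(1, hidden, gs_in, gs_in)
grid = F.interpolate(grid, scale_factor=gs_new / gs_in, mode="bilinear")
grid = grid.reshape(hidden, gs_new * gs_new).transpose(0, 1)
new_pos_emb = torch.cat((cls_row, grid), dim=0)       # [1 + 8x8, hidden]
assert new_pos_emb.shape == (1 + gs_new * gs_new, hidden)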
class VitModel(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self, num_classes, finetune=False):
        super(VitModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune

        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        # cls_token
        self.cls_token = torch.nn.Parameter(
            torch.randn(1, 1, self.hidden_size)
        )
        torch.nn.init.zeros_(self.cls_token)

        # Linear encoder
        self.linear_encoder = torch.nn.Linear(
            self.flatten_dim, self.hidden_size
        )

        # embedding
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.hidden_size
        )
        init_method_normal(args.init_method_std)(
            self.position_embeddings.weight
        )
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

        self.position_embeddings._register_load_state_dict_pre_hook(
            twod_interpolate_position_embeddings_hook
        )

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method, self.scaled_init_method
        )

        # MLP head
        if not self.finetune:
            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
        else:
            self.class_head = get_linear_layer(
                self.hidden_size, num_classes, torch.nn.init.zeros_
            )

    def forward(self, x):
        x = einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_dim,
            p2=self.patch_dim,
        )

        assert x.dtype == torch.half
        x = self.linear_encoder(x)

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.position_embeddings(self.position_ids)
        x = self.embedding_dropout(x)
        x = self.transformer(x, None)

        if not self.finetune:
            x = self.mlp_head(x)
        else:
            x = self.class_head(x[:, 0, :])

        return x
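A minimal sketch (not part of this commit) of the patch-flattening rearrange used in VitModel.forward above: an image is cut into patch_dim x patch_dim patches and each patch is flattened into one token. It requires the einops package (already imported by this file); sizes are illustrative only.

import torch
import einops

b, c, img, patch = 2, 3, 8, 4
images = torch.randn(b, c, img, img)
tokens = einops.rearrange(
    images, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch, p2=patch)
# (img / patch) ** 2 = 4 patches, each flattened to patch * patch * c = 48 values
assert tokens.shape == (b, (img // patch) ** 2, patch * patch * c)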
megatron/mpu/__init__.py
View file @ aed2f75e
...
...
@@ -38,13 +38,15 @@ from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized

from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes,
from .layers import (set_tensor_model_parallel_attributes,
                     set_defaults_if_not_set_tensor_model_parallel_attributes,
                     copy_tensor_model_parallel_attributes)

from .mappings import copy_to_tensor_model_parallel_region
...
...
@@ -57,6 +59,8 @@ from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks

from .utils import divide
from .utils import split_tensor_along_last_dim
megatron/mpu/data.py
View file @ aed2f75e
...
...
@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank


_MAX_DATA_DIM = 4
_MAX_DATA_DIM = 5


def _check_data_types(keys, data, target_dtype):
...
...
megatron/mpu/initialize.py
View file @ aed2f75e
...
...
@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
...
...
@@ -48,7 +51,8 @@ def is_unitialized():


def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1):
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None):
    """
    Initialize model data parallel groups.
...
...
@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
...
...
@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


def is_pipeline_first_stage():
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage():
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
...
...
@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


def get_pipeline_model_parallel_first_rank():
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
...
...
@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
...
...
@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())
...
...
megatron/mpu/layers.py
View file @ aed2f75e
...
...
@@ -43,6 +43,12 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_stride': 1}


def param_is_not_tensor_parallel_duplicate(param):
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (
                get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
...
...
@@ -260,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module):
                self.output_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=args.params_dtype))
            self.bias.tensor_model_parallel = True
            self.bias.partition_dim = 0
            self.bias.stride = stride
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
...
...
megatron/optimizer/__init__.py
View file @ aed2f75e
...
...
@@ -14,35 +14,35 @@
# limitations under the License.

from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD

from megatron import get_args
from megatron.model import import_layernorm
from megatron.model import LayerNorm

from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


def _get_params_for_weight_decay_optimization(module):
def _get_params_for_weight_decay_optimization(modules):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)

    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend(
                [p for p in list(module_._parameters.values())
                 if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n == 'bias'])
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, LayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in list(module_._parameters.values())
                     if p is not None])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params
...
...
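A minimal sketch (not part of this commit) of how the two parameter groups returned above are typically consumed: the no-weight-decay group overrides weight_decay for layernorm parameters and biases. Plain torch.optim.Adam is used here instead of apex's FusedAdam just to keep the example self-contained.

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.LayerNorm(8),
)

decay = {'params': []}
no_decay = {'params': [], 'weight_decay': 0.0}
for module in model.modules():
    if isinstance(module, torch.nn.LayerNorm):
        no_decay['params'].extend(p for p in module.parameters(recurse=False))
    else:
        for name, p in module.named_parameters(recurse=False):
            (no_decay if name == 'bias' else decay)['params'].append(p)

optimizer = torch.optim.Adam([decay, no_decay], lr=1e-3, weight_decay=0.01)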
@@ -52,28 +52,58 @@ def get_megatron_optimizer(model):

    # Base optimizer.
    param_groups = _get_params_for_weight_decay_optimization(model)
    optimizer = Adam(param_groups,
                     lr=args.lr,
                     weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)
    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'sgd':
        optimizer = SGD(param_groups,
                        lr=args.lr,
                        weight_decay=args.weight_decay,
                        momentum=args.sgd_momentum)
    else:
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
        params_have_main_grad = True

    if args.fp16:
    if args.fp16 or args.bf16:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            grad_scaler = DynamicGradScaler(
                initial_scale=args.initial_loss_scale,
                min_scale=args.min_loss_scale,
                growth_factor=2.0,
                backoff_factor=0.5,
                growth_interval=args.loss_scale_window,
                hysteresis=args.hysteresis)
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
                                           args.clip_grad)
        return Float16OptimizerWithFloat16Params(optimizer,
                                                 args.clip_grad,
                                                 args.log_num_zeros_in_grad,
                                                 params_have_main_grad,
                                                 args.bf16,
                                                 grad_scaler)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad)
    return FP32Optimizer(optimizer, args.clip_grad,
                         args.log_num_zeros_in_grad,
                         params_have_main_grad)
megatron/optimizer/clip_grads.py
View file @ aed2f75e
...
...
@@ -22,6 +22,8 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
...
...
@@ -54,9 +56,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = not hasattr(param, 'shared') or not param.shared
        is_not_tp_duplicate = param.tensor_model_parallel or \
            (mpu.get_tensor_model_parallel_rank() == 0)
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        grad = param.grad.detach()
        if grad_not_none:
            # Make sure the grads are in fp32
...
...
@@ -117,3 +118,31 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
            clip_coeff)

    return total_norm


def count_zeros_fp32(parameters):

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    total_num_zeros = 0.0
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros
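A minimal sketch (not part of this commit) of the zero-gradient counting above, without the model-parallel all-reduce (single-process version), using the same count_nonzero idiom.

import torch

def count_zero_grads(params):
    total = 0
    for p in params:
        if p.grad is not None:
            g = p.grad.detach()
            total += g.numel() - int(torch.count_nonzero(g))
    return total

w = torch.nn.Parameter(torch.ones(4))
w.grad = torch.tensor([0.0, 1.0, 0.0, 2.0])
assert count_zero_grads([w]) == 2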
megatron/optimizer/optimizer.py
View file @ aed2f75e
...
...
@@ -27,7 +27,7 @@ from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0

from .clip_grads import clip_grad_norm_fp32
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32


def _zero_grad_group_helper(group, set_to_none):
...
...
@@ -46,49 +46,77 @@ def _zero_grad_group_helper(group, set_to_none):

def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Use multi-tensor-applier to copy values from one list to another."""
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        overflow_buf = torch.cuda.IntTensor([0])
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
        for this_, that_ in zip(this, that):
            that_.copy_(this_)


class MegatronOptimizer(ABC):

    def __init__(self, optimizer):
    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):
        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad

    def clip_grad_norm(self, clip_grad):
    def get_parameters(self):
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        clip_grad_norm_fp32(params, clip_grad)
        return params

    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        params = self.get_parameters()
        return count_zeros_fp32(params)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
...
...
@@ -98,14 +126,17 @@ class MegatronOptimizer(ABC):
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
...
...
@@ -116,6 +147,7 @@ class MegatronOptimizer(ABC):
    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
...
@@ -129,49 +161,90 @@ class MegatronOptimizer(ABC):
-class FP16OptimizerWithFP16Params(MegatronOptimizer):
-
-    def __init__(self, optimizer, grad_scaler, clip_grad):
-        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
+class Float16OptimizerWithFloat16Params(MegatronOptimizer):
+    """Float16 optimizer for fp16 and bf16 data types.
+
+    Arguments:
+        optimizer: base optimizer such as Adam or SGD
+        clip_grad: clip gradients with this global L2 norm. Note
+            that clipping is ignored if clip_grad == 0
+        log_num_zeros_in_grad: return number of zeros in the gradients.
+        params_have_main_grad: flag indicating if parameters have
+            a `main_grad` field. If this is set, we are assuming
+            that the model parameters are stored in the `main_grad`
+            field instead of the typical `grad` field. This happens
+            for the DDP cases where there is a contiguous buffer
+            holding the gradients. For example for bfloat16, we want
+            to do gradient accumulation and all-reduces in float32
+            and as a result we store those gradients in the main_grad.
+            Note that main grad is not necessarily in float32.
+        bf16: if true, the model is running in bfloat16.
+        grad_scaler: used for scaling gradients. Note that this can be
+            None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constant gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+    """
+
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, bf16, grad_scaler):
+        super(Float16OptimizerWithFloat16Params, self).__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad,
+            params_have_main_grad)
+
+        self.bf16 = bf16
        self.grad_scaler = grad_scaler
-        self.clip_grad = clip_grad
+        # None grad scaler is only supported for bf16.
+        if self.grad_scaler is None:
+            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
-        self.found_inf = torch.cuda.FloatTensor([0.0])
+        # Note that we keep this for the cases that grad scaler is none.
+        # We still record nan/inf if we have a bfloat16 with a grad scaler.
+        if self.grad_scaler:
+            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # For bfloat, we don't have multi-tensor apply and for now
+        # we set it to none so the multi-tensor apply gets ignored.
+        if bf16:
+            self._dummy_overflow_buf = None
+        else:
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

+        # In case grad scaler is not passed, define the unity scale.
+        if self.grad_scaler is None:
+            self._scale_one = torch.cuda.FloatTensor([1.0])
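The docstring spells out which precision/scaler combinations are legal: fp16 always needs a grad scaler, while bf16 may run without one, in which case the optimizer falls back to a unit scale. A tiny illustrative check of that contract (not library code):

    def check_precision_config(bf16, grad_scaler):
        # Mirrors the constructor's assertion and the _scale_one fallback.
        if grad_scaler is None:
            assert bf16, 'fp16 expects a grad scaler.'
            return 1.0          # unity scale path
        return 'scaler-driven'  # dynamic or constant scale from grad_scaler

    print(check_precision_config(bf16=True, grad_scaler=None))   # 1.0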
        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
-        #   fp16_groups: original fp16 parameters
-        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
-        self.fp16_groups = []
-        self.fp32_from_fp16_groups = []
+        self.float16_groups = []
+        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
-            fp16_params_this_group = []
+            float16_params_this_group = []
            fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
+            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

-                    # fp16 params:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        fp16_params_this_group.append(param)
+                    # float16 params:
+                    if param.type() in ['torch.cuda.HalfTensor',
+                                        'torch.cuda.BFloat16Tensor']:
+                        float16_params_this_group.append(param)

                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Store grads
                        main_param.requires_grad = True
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
...
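The loop above builds, for every trainable half/bf16 parameter, an fp32 "main" copy that the base optimizer will actually update. A minimal sketch of that copy step in isolation (tensor-model-parallel attribute copying via mpu is omitted; names here are illustrative):

    import torch

    param = torch.nn.Parameter(torch.randn(4, 4).half())   # model weight in fp16
    main_param = param.detach().clone().float()             # fp32 master copy
    main_param.requires_grad = True                          # optimizer steps this copy

    assert main_param.dtype == torch.float32
    assert torch.allclose(main_param, param.float())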
@@ -179,7 +252,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
-                        fp32_from_fp16_params_this_group.append(main_param)
+                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
...
@@ -191,13 +264,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                        param_group['params'][i] = param

                    else:
-                        raise TypeError("Wrapped parameters must be either "
-                                        "torch.cuda.FloatTensor or "
-                                        "torch.cuda.HalfTensor. "
-                                        "Received {}".format(param.type()))
-
-            self.fp16_groups.append(fp16_params_this_group)
-            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+                        raise TypeError('Wrapped parameters must be one of '
+                                        'torch.cuda.FloatTensor, '
+                                        'torch.cuda.HalfTensor, or '
+                                        'torch.cuda.BFloat16Tensor. '
+                                        'Received {}'.format(param.type()))
+
+            self.float16_groups.append(float16_params_this_group)
+            self.fp32_from_float16_groups.append(
+                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
...
@@ -207,37 +282,40 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
-        fp16_groups & fp32_from_fp32_groups."""
-        for group in self.fp16_groups:
+        float16_groups & fp32_from_fp32_groups."""
+        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
+        if self.grad_scaler is None:
+            return self._scale_one
        return self.grad_scaler.scale

    def _copy_model_grads_to_main_grads(self):
-        # This only needs to be done for the fp16 group.
-        model_grads = []
-        main_grads = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
-                if model_param.grad is not None:
-                    if main_param.grad is None:
-                        main_param.grad = torch.empty_like(main_param)
-                    model_grads.append(model_param.grad.data)
-                    main_grads.append(main_param.grad.data)
-        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
-                                        overflow_buf=self._dummy_overflow_buf)
+                if self.params_have_main_grad:
+                    main_param.grad = model_param.main_grad.float()
+                else:
+                    if model_param.grad is not None:
+                        main_param.grad = model_param.grad.float()
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        if self.params_have_main_grad:
+            for model_group in self.fp32_from_fp32_groups:
+                for model_param in model_group:
+                    model_param.grad = model_param.main_grad
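With the new params_have_main_grad path, the DDP wrapper accumulates gradients into a contiguous main_grad buffer on each model parameter, and that buffer (cast to fp32) becomes the .grad of the fp32 main parameter. A CPU-only sketch of that hand-off; only the main_grad attribute name comes from the code above, the rest is illustrative:

    import torch

    model_param = torch.nn.Parameter(torch.randn(3).half())
    model_param.main_grad = torch.randn(3).half()       # filled by the DDP grad buffer
    main_param = model_param.detach().clone().float()
    main_param.requires_grad = True

    # params_have_main_grad=True branch of _copy_model_grads_to_main_grads:
    main_param.grad = model_param.main_grad.float()
    assert main_param.grad.dtype == torch.float32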
    def _unscale_main_grads_and_check_for_nan(self):
        main_grads = []
-        # fp32 params from fp16 ones.
-        for main_group in self.fp32_from_fp16_groups:
+        # fp32 params from float16 ones.
+        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
...
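Unscaling reverses the loss scale on the fp32 main grads and records whether any inf/nan slipped in, so step() can skip the update. The real implementation defers to a fused kernel; the sketch below is a CPU-friendly illustration of the same idea, not the code used here:

    import torch

    def unscale_and_check(main_grads, loss_scale):
        # Divide grads by the scale in place; flag any non-finite value.
        found_inf = False
        inv_scale = 1.0 / loss_scale
        for grad in main_grads:
            grad.mul_(inv_scale)
            if not torch.isfinite(grad).all():
                found_inf = True
        return found_inf

    grads = [torch.tensor([2048.0, 4096.0]), torch.tensor([float('inf')])]
    print(unscale_and_check(grads, loss_scale=1024.0))   # True -> skip this update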
@@ -261,11 +339,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return found_inf_flag

-    def _get_model_and_main_params_data_fp16(self):
+    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
...
@@ -273,15 +351,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

    def _copy_main_params_to_model_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)
...
@@ -300,24 +378,34 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

-        # Unscale and check for inf/nan.
-        timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-        timers('optimizer-unscale-and-check-inf').stop()
-
-        # We are done with scaling gradients
-        # so we can update the loss scale.
-        self.grad_scaler.update(found_inf_flag)
-
-        # If we found inf/nan, skip the update.
-        if found_inf_flag:
-            return False
+        # Do unscale, check for inf, and update grad scaler only for
+        # the case that grad scaler is provided.
+        if self.grad_scaler:
+
+            # Unscale and check for inf/nan.
+            timers('optimizer-unscale-and-check-inf').start()
+            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
+            timers('optimizer-unscale-and-check-inf').stop()
+
+            # We are done with scaling gradients
+            # so we can update the loss scale.
+            self.grad_scaler.update(found_inf_flag)
+
+            # If we found inf/nan, skip the update.
+            if found_inf_flag:
+                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
-        self.clip_grad_norm(self.clip_grad)
+        grad_norm = None
+        if self.clip_grad > 0.0:
+            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

+        # count the zeros in the grads
+        num_zeros_in_grad = self.count_zeros() if \
+            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()
...
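Note the changed contract here: step() now reports whether the update was applied together with the (optional) gradient norm and zero count, instead of a bare boolean. A hedged usage sketch, assuming `optimizer` is one of the wrappers defined in this file:

    def training_step(optimizer):
        # New three-value return of step().
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
        if not update_successful:
            # inf/nan was found in the scaled grads; the scaler lowered the
            # loss scale and this iteration's update was skipped.
            return None, None
        return grad_norm, num_zeros_in_grad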
@@ -327,14 +415,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
-        return True
+        return True, grad_norm, num_zeros_in_grad

    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
-        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict
...
@@ -352,15 +441,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
-            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** found the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
-        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_fp16_params_key not in state_dict:
-            fp32_from_fp16_params_key = 'fp32_from_fp16'
+        fp32_from_float16_params_key = 'fp32_from_fp16_params'
+        if fp32_from_float16_params_key not in state_dict:
+            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
-                self.fp32_from_fp16_groups,
-                state_dict[fp32_from_fp16_params_key]):
+                self.fp32_from_float16_groups,
+                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
...
@@ -368,10 +462,14 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, clip_grad):
-        super(FP32Optimizer, self).__init__(optimizer)
-        self.clip_grad = clip_grad
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
+        super(FP32Optimizer, self).__init__(
+            optimizer, clip_grad,
+            log_num_zeros_in_grad,
+            params_have_main_grad)

        self._scale = torch.cuda.FloatTensor([1.0])
...
@@ -391,15 +489,26 @@ class FP32Optimizer(MegatronOptimizer):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow."""

+        # Copy main_grads to grads.
+        if self.params_have_main_grad:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.grad = param.main_grad

        # Clip gradients.
+        grad_norm = None
        if self.clip_grad > 0.0:
-            self.clip_grad_norm(self.clip_grad)
+            grad_norm = self.clip_grad_norm(self.clip_grad)

+        # count the zeros in the grads
+        num_zeros_in_grad = self.count_zeros() if \
+            self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
-        return True
+        return True, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
...
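For context, a caller would pick one of the two wrappers based on the precision flags; the sketch below only approximates the factory logic (the actual selection lives in megatron/optimizer/__init__.py, also touched by this commit), and the args.* names are assumptions in the style of Megatron's arguments:

    from megatron.optimizer.optimizer import (
        Float16OptimizerWithFloat16Params, FP32Optimizer)

    def wrap_optimizer(base_optimizer, args, grad_scaler=None,
                       params_have_main_grad=False):
        # Hypothetical helper, not the real get_megatron_optimizer().
        if args.fp16 or args.bf16:
            return Float16OptimizerWithFloat16Params(
                base_optimizer, args.clip_grad, args.log_num_zeros_in_grad,
                params_have_main_grad, args.bf16, grad_scaler)
        return FP32Optimizer(base_optimizer, args.clip_grad,
                             args.log_num_zeros_in_grad, params_have_main_grad)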