OpenDAS / apex · Commits · db92ee13

Commit db92ee13 (unverified), authored Dec 14, 2021 by Jithun Nair, committed by GitHub on Dec 14, 2021.

    Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

    IFU-master-2021-12-08

Parents: d150afdc, 68364b49
Changes: 98

Showing 20 changed files with 1938 additions and 94 deletions (+1938 / -94).
Changed files:

    apex/transformer/tensor_parallel/layers.py                  +12    -12
    apex/transformer/tensor_parallel/mappings.py                 +5     -5
    apex/transformer/tensor_parallel/memory.py                   +1     -1
    apex/transformer/tensor_parallel/random.py                  +10    -27
    apex/transformer/tensor_parallel/utils.py                    +5    -15
    apex/transformer/testing/__init__.py                         +0     -0
    apex/transformer/testing/arguments.py                       +56    -16
    apex/transformer/testing/commons.py                         +38     -2
    apex/transformer/testing/global_vars.py                     +19     -9
    apex/transformer/testing/standalone_bert.py                +217     -0
    apex/transformer/testing/standalone_gpt.py                +1504     -0
    apex/transformer/utils.py                                   +36     -0
    csrc/amp_C_frontend.cpp                                     +29     -0
    csrc/megatron/scaled_masked_softmax.cpp                      +1     -1
    csrc/megatron/scaled_masked_softmax.h                        +1     -1
    csrc/megatron/scaled_masked_softmax_cuda.cu                  +1     -1
    csrc/megatron/scaled_upper_triang_masked_softmax.cpp         +1     -1
    csrc/megatron/scaled_upper_triang_masked_softmax.h           +1     -1
    csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu     +1     -1
    csrc/multi_tensor_apply.cuh                                  +0     -1
apex/transformer/tensor_parallel/layers.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -21,17 +21,17 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
 
-from ..._autocast_utils import _cast_if_autocast_enabled
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
-from .random import get_cuda_rng_tracker
-from .utils import divide
-from .utils import VocabUtility
+from apex._autocast_utils import _cast_if_autocast_enabled
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.utils import divide
+from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
+from apex.transformer.tensor_parallel.utils import VocabUtility
 
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
 ...
apex/transformer/tensor_parallel/mappings.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -14,10 +14,10 @@
 # limitations under the License.
 import torch
 
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_world_size
-from ..parallel_state import get_tensor_model_parallel_rank
-from .utils import split_tensor_along_last_dim
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
 
 
 def _reduce(input_):
 ...
apex/transformer/tensor_parallel/memory.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
apex/transformer/tensor_parallel/random.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO (mkozuki): Audit this file.
+# I don't think some functions strongly relate to `random` in tensor_parallel.
+# Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
 ...
@@ -23,18 +26,17 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable
 
-from ..parallel_state import get_data_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .memory import allocate_mem_buff
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.tensor_parallel.memory import allocate_mem_buff
+from apex.transformer.utils import split_tensor_into_1d_equal_chunks
+from apex.transformer.utils import gather_split_1d_tensor
 
 
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
-# Whether apply model parallelsim to checkpointed hidden states.
+# Whether apply model parallelism to checkpointed hidden states.
 _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
 ...
@@ -108,26 +110,6 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)
 
 
-def split_tensor_into_1d_equal_chunks(tensor):
-    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    return data[start_index:end_index]
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    world_size = get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
-    torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group())
-    return gathered
-
-
 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
 ...
@@ -238,6 +220,7 @@ def model_parallel_cuda_manual_seed(seed):
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
 
 
+# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
 class CheckpointFunction(torch.autograd.Function):
     """This function is adapted from torch.utils.checkpoint with
        two main changes:
 ...
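The two helpers deleted above are not dropped from the package; the new imports at the top of this hunk pull split_tensor_into_1d_equal_chunks and gather_split_1d_tensor from apex.transformer.utils (the +36/-0 file in this commit). A minimal round-trip sketch of what they do, assuming torch.distributed and apex.transformer.parallel_state have already been initialized; this is illustrative and not part of the diff:

    import torch

    from apex.transformer.utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor


    def roundtrip(full_tensor: torch.Tensor) -> torch.Tensor:
        # Assumes parallel_state.initialize_model_parallel(...) was called and that
        # full_tensor.numel() is divisible by the tensor-model-parallel world size.
        local_chunk = split_tensor_into_1d_equal_chunks(full_tensor)  # this rank's 1D slice
        gathered = gather_split_1d_tensor(local_chunk)                # all-gather of all slices
        return gathered.view_as(full_tensor)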
apex/transformer/tensor_parallel/utils.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -14,17 +14,7 @@
 # limitations under the License.
 import torch
 
-
-def ensure_divisibility(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
-
-
-def divide(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator and return
-    the division value."""
-    ensure_divisibility(numerator, denominator)
-    return numerator // denominator
+from apex.transformer.utils import divide
 
 
 def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
 ...
@@ -48,9 +38,9 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
 
 class VocabUtility:
-    """Split the vocabulary into `world_size` chunks amd return the
+    """Split the vocabulary into `world_size` chunks and return the
     first and last index of the vocabulary belonging to the `rank`
-    partition: Note that indecies in [fist, last)"""
+    partition: Note that indices in [fist, last)"""
 
     @staticmethod
     def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
 ...
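Per the new import above, divide (with its ensure_divisibility check) now lives in apex.transformer.utils. A quick illustrative use, assuming the relocated copy keeps the removed semantics (assert divisibility, then integer-divide):

    from apex.transformer.utils import divide

    per_partition = divide(4096, 8)   # e.g. 4096 hidden units across 8 tensor-parallel ranks
    assert per_partition == 512
    # divide(10, 3) would raise an AssertionError ("10 is not divisible by 3" in the removed copy).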
apex/transformer/tensor_parallel/tests/__init__.py → apex/transformer/testing/__init__.py

 File moved
apex/transformer/tensor_parallel/tests/arguments.py → apex/transformer/testing/arguments.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Megatron arguments."""
 
 import argparse
 import os
 
 import torch
 
 
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
 ...
@@ -79,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
               args.world_size, args.data_parallel_size,
               args.tensor_model_parallel_size,
               args.pipeline_model_parallel_size), flush=True)
+    if args.pipeline_model_parallel_size > 1:
+        if args.pipeline_model_parallel_split_rank is not None:
+            assert args.pipeline_model_parallel_split_rank < \
+                    args.pipeline_model_parallel_size, 'split rank needs' \
+                    ' to be less than pipeline model parallel size ({})'.format(
+                            args.pipeline_model_parallel_size)
 
     # Deprecated arguments
     assert args.batch_size is None, '--batch-size argument is no longer ' \
 ...
@@ -90,6 +97,13 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
+    if args.checkpoint_activations:
+        args.activations_checkpoint_method = 'uniform'
+        if args.rank == 0:
+            print('--checkpoint-activations is no longer valid, '
+                  'use --activation-checkpoint-method instead. '
+                  'Defaulting to activation-checkpoint-method=uniform.')
+        del args.checkpoint_activations
 
     # Set input defaults.
     for key in defaults:
 ...
@@ -147,16 +161,15 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)
 
     # If we do accumulation and all-reduces in fp32, we need to have
-    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    # local DDP and we should make sure use-contiguous-buffers-in-local-ddp is not off.
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
-        args.use_contiguous_buffers_in_ddp = True
+        assert args.use_contiguous_buffers_in_local_ddp
 
-    # If we use a contiguous buffer to hold main grads, we need to have
-    # local DDP.
-    if args.use_contiguous_buffers_in_ddp:
-        assert args.DDP_impl == 'local'
+    # For torch DDP, we do not use contiguous buffer
+    if args.DDP_impl == 'torch':
+        args.use_contiguous_buffers_in_local_ddp = False
 
     if args.dataloader_type is None:
         args.dataloader_type = 'single'
 ...
@@ -233,9 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'
 
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
+        assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+            'need to use a activation-checkpoint method '
+        assert args.num_layers_per_virtual_pipeline_stage is None, \
+            'currently distrobuted checkpoint activations only supported for ' \
+            'nointerleaved pipeline parallelism'
 
     _print_args(args)
     return args
 ...
@@ -401,8 +420,20 @@ def _add_training_args(parser):
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) checkpoint the input activations of only a set number of '
+                       'individual Transformer layers per pipeline stage and do the '
+                       'rest without any checkpointing'
+                       'default) do not apply activations checkpoint to any layers')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '
 ...
@@ -437,6 +468,11 @@ def _add_training_args(parser):
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
+    group.add_argument('--no-async-tensor-model-parallel-allreduce',
+                       action='store_true',
+                       help='Disable asynchronous execution of '
+                       'tensor-model-parallel all-reduce with weight '
+                       'gradient compuation of a column-linear layer.')
     return parser
 ...
@@ -571,6 +607,9 @@ def _add_distributed_args(parser):
                        help='Degree of tensor model parallelism.')
     group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                        help='Degree of pipeline model parallelism.')
+    group.add_argument('--pipeline-model-parallel-split-rank',
+                       type=int, default=None,
+                       help='Rank where encoder and decoder should be split.')
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
 ...
@@ -583,9 +622,10 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
-    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                       help='If set, use contiguous buffer in DDP. Note that '
-                       'this option only works woth local DDP.')
+    group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                       action='store_false', help='If set, dont use '
+                       'contiguous buffer in local DDP.',
+                       dest='use_contiguous_buffers_in_local_ddp')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                        action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
 ...
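To see how the reworked flags behave, here is a small standalone argparse sketch (not the apex parser itself) that reproduces just the definitions added or renamed above: the activations-checkpoint pair and the inverted contiguous-buffer switch.

    import argparse

    parser = argparse.ArgumentParser()
    # Copied from the hunks above; only these three arguments are reproduced here.
    parser.add_argument('--activations-checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'])
    parser.add_argument('--activations-checkpoint-num-layers', type=int, default=1)
    parser.add_argument('--no-contiguous-buffers-in-local-ddp', action='store_false',
                        dest='use_contiguous_buffers_in_local_ddp')

    args = parser.parse_args(['--activations-checkpoint-method', 'uniform'])
    assert args.activations_checkpoint_method == 'uniform'
    assert args.activations_checkpoint_num_layers == 1
    # store_false into a positively named dest: defaults to True, flips only when the flag is passed.
    assert args.use_contiguous_buffers_in_local_ddp is True
    args = parser.parse_args(['--no-contiguous-buffers-in-local-ddp'])
    assert args.use_contiguous_buffers_in_local_ddp is False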
apex/transformer/tensor_parallel/tests/commons.py → apex/transformer/testing/commons.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -14,17 +14,53 @@
 # limitations under the License.
 import os
 import random
+from typing import Optional, Union, List
 
 import numpy
 import torch
+import torch.nn as nn
 
 from apex import transformer
-from apex.transformer.tensor_parallel.tests import global_vars
+from apex.transformer.testing import global_vars
 
 
 TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
 
 
+# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
+class MyLayer(nn.Module):
+
+    def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
+        super().__init__()
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.layer = nn.Linear(hidden_size, hidden_size)
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class MyModel(nn.Module):
+
+    def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
+        super().__init__()
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
+        self.input_tensor = None
+
+    def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
+        self.input_tensor = input_tensor
+
+    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
+        if self.input_tensor is None:
+            return self.layer(x)
+        return self.layer(self.input_tensor)
+
+
+def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
+    return MyModel(hidden_size, pre_process, post_process)
+
+
 class IdentityLayer(torch.nn.Module):
     def __init__(self, size, scale=1.0):
         super(IdentityLayer, self).__init__()
 ...
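The new MyModel mirrors the set_input_tensor hook that pipeline-parallel schedules use to hand a stage the previous stage's output. A standalone illustration with arbitrary shapes (not from the commit):

    import torch

    from apex.transformer.testing.commons import MyModel

    model = MyModel(hidden_size=8)
    x = torch.randn(2, 8)

    out_direct = model(x)              # no injected tensor: forward consumes its argument

    model.set_input_tensor(torch.randn(2, 8))
    out_injected = model(None)         # injected tensor takes precedence over the argument
    assert out_direct.shape == out_injected.shape == (2, 8)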
apex/transformer/tensor_parallel/tests/global_vars.py → apex/transformer/testing/global_vars.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -20,8 +20,8 @@ import time
 import torch
 
-from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
-from apex.transformer.tensor_parallel.tests.arguments import parse_args
+from apex.transformer.microbatches import build_num_microbatches_calculator
+from .arguments import parse_args
 
 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 ...
@@ -37,17 +37,27 @@ def get_args():
     return _GLOBAL_ARGS
 
 
-def get_num_microbatches():
+def get_num_microbatches() -> int:
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
 
 
-def get_current_global_batch_size():
+def get_current_global_batch_size() -> int:
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
 
 
-def update_num_microbatches(consumed_samples, consistency_check=True):
-    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
-                                               consistency_check)
+def update_num_microbatches(consumed_samples: int, *, consistency_check: bool = True) -> None:
+    """Update the number of microbatches upon the number of consumed samples.
+
+    .. note::
+        This function has no effect unless ``rampup_batch_size`` is set.
+
+    Args:
+        consumed_samples: The number of consumed samples so far. Basically this is equal to
+            :math:`num_iter * global_batch_size`.
+        consistency_check: If :obj:`True`, sanity checks the consumed samples, i.e., check if
+            ``consumed_samples`` is divisible by :math:`micro_batch_size \times data_parallel_size`.
+    """
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
 
 
 # def get_tokenizer():
 ...
@@ -80,7 +90,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     args = _parse_args(extra_args_provider=extra_args_provider,
                        defaults=args_defaults,
                        ignore_unknown_args=ignore_unknown_args)
-    _build_num_microbatches_calculator(args)
+    # _build_num_microbatches_calculator(args)
     # if args.vocab_file:
     #     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
 ...
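The docstring added above spells out the now keyword-only consistency_check. A sketch of the call pattern in a training loop, assuming set_global_variables(...) has already built the global microbatches calculator (sizes are placeholders, not from the commit):

    from apex.transformer.testing.global_vars import get_num_microbatches, update_num_microbatches

    global_batch_size = 256
    for iteration in range(10):
        # consumed_samples ~= num_iter * global_batch_size, as described in the docstring.
        update_num_microbatches(iteration * global_batch_size, consistency_check=True)
        num_microbatches = get_num_microbatches()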
apex/transformer/testing/standalone_bert.py (new file, mode 100644)

import torch

from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.testing.global_vars import get_args
from .standalone_gpt import get_language_model, get_linear_layer, init_method_normal, parallel_lm_logits, scaled_init_method_normal
from .standalone_gpt import MegatronModule


def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert attention mask to binary:
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask


def bert_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class BertLMHead(MegatronModule):
    """Masked LM head for Bert

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method, layernorm_epsilon, parallel_output):
        super(BertLMHead, self).__init__()
        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # TODO: do we need this?
        # mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias)
        return output


def post_language_model_processing(lm_output, pooled_output, lm_head, binary_head,
                                   lm_labels, logit_weights, fp16_lm_cross_entropy):
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
        return lm_loss, binary_logits


class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self, num_tokentypes=2, add_binary_head=True, parallel_output=True,
                 pre_process=True, post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2, init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask, tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)


def bert_model_provider(pre_process=True, post_process=True):
    model = BertModel(num_tokentypes=0, add_binary_head=False,
                      pre_process=pre_process, post_process=post_process)
    return model
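As a quick illustration of the mask arithmetic in bert_extended_attention_mask, here is the same computation on a toy 2D padding mask (1 = token, 0 = pad); standalone and not part of the commit:

    import torch

    attention_mask = torch.tensor([[1, 1, 0]])        # [b=1, s=3]
    b1s = attention_mask.unsqueeze(1)                 # [1, 1, 3]
    bs1 = attention_mask.unsqueeze(2)                 # [1, 3, 1]
    extended = (b1s * bs1).unsqueeze(1) < 0.5         # [1, 1, 3, 3]; True marks masked positions
    # extended[0, 0]:
    # [[False, False,  True],
    #  [False, False,  True],
    #  [ True,  True,  True]]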
apex/transformer/testing/standalone_gpt.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import enum
import math

import torch
import torch.nn.functional as F

import apex.transformer.utils
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer.functional import FusedScaleMaskSoftmax
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.testing.global_vars import get_args
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType


_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)


class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2


###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g
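For context, the constants above encode the standard tanh approximation of GELU; a small standalone check against the exact erf form quoted in the comments (not part of the commit):

    import torch

    x = torch.linspace(-4.0, 4.0, steps=9)
    approx = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))  # tanh form, bias folded in
    exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))                             # "actual gelu" per the comment
    assert torch.allclose(approx, exact, atol=1e-3)   # agreement to a few 1e-4 over this range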
class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, share_word_embeddings=True):
        super(MegatronModule, self).__init__()
        self.share_word_embeddings = share_word_embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)

    def word_embeddings_weight(self):
        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) or \
                parallel_state.get_pipeline_model_parallel_world_size() == 1:
            return self.language_model.embedding.word_embeddings.weight
        else:
            if not self.share_word_embeddings:
                raise Exception("word_embeddings_weight() called for last "
                                "stage, but share_word_embeddings is false")
            return self.word_embeddings.weight

    def initialize_word_embeddings(self, init_method_normal):
        args = get_args()
        if not self.share_word_embeddings:
            raise Exception("initialize_word_embeddings() was called but "
                            "share_word_embeddings is false")

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
        # using pipeline parallelism.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layers, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, before an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if parallel_state.is_pipeline_last_stage():
            assert not parallel_state.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = "word_embeddings_for_head"
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                args.padded_vocab_size,
                args.hidden_size,
                init_method=init_method_normal(args.init_method_std),
                use_cpu_initialization=args.use_cpu_initialization,
            )
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and \
                not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and \
                parallel_state.is_rank_in_embedding_group():
            self.language_model.embedding.zero_parameters()

        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if parallel_state.is_rank_in_embedding_group():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=parallel_state.get_embedding_group())

            # All-reduce other embeddings as well as necessary. The last stage
            # does not have these other embeddings, so just create placeholder
            # tensors of the right shape with all zeros.
            # NOTE: We don't currently support T5 with the interleaved schedule.
            if args.pipeline_model_parallel_split_rank is not None:
                # TODO: Support tokentype embedding.
                dimensions = (args.max_position_embeddings, args.hidden_size)
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    position_embeddings = torch.nn.Embedding(*dimensions).cuda()
                    position_embeddings.weight.data.fill_(0)
                else:
                    self.language_model.embedding.cuda()
                    position_embeddings = self.language_model.embedding.position_embeddings
                torch.distributed.all_reduce(position_embeddings.weight.data,
                                             group=parallel_state.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model this is fine, but "
                  "this needs to be handled manually. If you are training "
                  "something is definitely wrong.")


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    return gelu_impl(x)


# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super().__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            use_cpu_initialization=args.use_cpu_initialization,
        )

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=args.use_cpu_initialization,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
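A shape-only sketch of the h -> 4h -> h flow described in the ParallelMLP docstring, using plain nn.Linear stand-ins for the column/row-parallel layers (the real layers shard the 4h dimension across tensor-parallel ranks); illustrative only:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    s, b, h = 16, 2, 64                    # sequence, batch, hidden
    dense_h_to_4h = nn.Linear(h, 4 * h)    # stand-in for ColumnParallelLinear
    dense_4h_to_h = nn.Linear(4 * h, h)    # stand-in for RowParallelLinear

    x = torch.randn(s, b, h)                   # [s, b, h]
    intermediate = F.gelu(dense_h_to_4h(x))    # [s, b, 4h]
    out = dense_4h_to_h(intermediate)          # [s, b, h]
    assert out.shape == (s, b, h)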
class
ParallelAttention
(
MegatronModule
):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def
__init__
(
self
,
init_method
,
output_layer_init_method
,
layer_number
,
attention_type
=
AttnType
.
self_attn
,
attn_mask_type
=
AttnMaskType
.
padding
,
):
super
().
__init__
()
args
=
get_args
()
self
.
fp16
=
args
.
fp16
self
.
bf16
=
args
.
bf16
self
.
apply_query_key_layer_scaling
=
args
.
apply_query_key_layer_scaling
self
.
attention_softmax_in_fp32
=
args
.
attention_softmax_in_fp32
if
self
.
apply_query_key_layer_scaling
:
self
.
attention_softmax_in_fp32
=
True
self
.
layer_number
=
max
(
1
,
layer_number
)
self
.
attention_type
=
attention_type
self
.
attn_mask_type
=
attn_mask_type
self
.
params_dtype
=
args
.
params_dtype
projection_size
=
args
.
kv_channels
*
args
.
num_attention_heads
# Per attention head and per partition values.
world_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
hidden_size_per_partition
=
apex
.
transformer
.
utils
.
divide
(
projection_size
,
world_size
)
self
.
hidden_size_per_attention_head
=
apex
.
transformer
.
utils
.
divide
(
projection_size
,
args
.
num_attention_heads
)
self
.
num_attention_heads_per_partition
=
apex
.
transformer
.
utils
.
divide
(
args
.
num_attention_heads
,
world_size
)
# Strided linear layer.
if
attention_type
==
AttnType
.
self_attn
:
self
.
query_key_value
=
tensor_parallel
.
ColumnParallelLinear
(
args
.
hidden_size
,
3
*
projection_size
,
gather_output
=
False
,
init_method
=
init_method
,
use_cpu_initialization
=
args
.
use_cpu_initialization
)
else
:
assert
attention_type
==
AttnType
.
cross_attn
self
.
query
=
tensor_parallel
.
ColumnParallelLinear
(
args
.
hidden_size
,
projection_size
,
gather_output
=
False
,
init_method
=
init_method
,
use_cpu_initialization
=
args
.
use_cpu_initialization
)
self
.
key_value
=
tensor_parallel
.
ColumnParallelLinear
(
args
.
hidden_size
,
2
*
projection_size
,
gather_output
=
False
,
init_method
=
init_method
,
use_cpu_initialization
=
args
.
use_cpu_initialization
)
coeff
=
None
self
.
norm_factor
=
math
.
sqrt
(
self
.
hidden_size_per_attention_head
)
if
self
.
apply_query_key_layer_scaling
:
coeff
=
self
.
layer_number
self
.
norm_factor
*=
coeff
self
.
scale_mask_softmax
=
FusedScaleMaskSoftmax
(
self
.
fp16
,
self
.
bf16
,
self
.
attn_mask_type
,
args
.
masked_softmax_fusion
,
attention_mask_func
,
self
.
attention_softmax_in_fp32
,
coeff
,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self
.
attention_dropout
=
torch
.
nn
.
Dropout
(
args
.
attention_dropout
)
# Output.
self
.
dense
=
tensor_parallel
.
RowParallelLinear
(
projection_size
,
args
.
hidden_size
,
input_is_parallel
=
True
,
init_method
=
output_layer_init_method
,
skip_bias_add
=
True
,
use_cpu_initialization
=
args
.
use_cpu_initialization
)
# Inference key-value memory
self
.
inference_key_memory
=
None
self
.
inference_value_memory
=
None
self
.
inference_current_sequence_len
=
0
def
_allocate_memory
(
self
,
inference_max_sequence_len
,
batch_size
):
return
torch
.
empty
(
inference_max_sequence_len
,
batch_size
,
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
,
dtype
=
self
.
params_dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
set_inference_key_value_memory
=
False
,
inference_max_sequence_len
=
None
,
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if
set_inference_key_value_memory
:
assert
inference_max_sequence_len
and
inference_max_sequence_len
>
0
self
.
inference_key_memory
=
self
.
_allocate_memory
(
inference_max_sequence_len
,
hidden_states
.
size
(
1
))
self
.
inference_value_memory
=
self
.
_allocate_memory
(
inference_max_sequence_len
,
hidden_states
.
size
(
1
))
self
.
inference_current_sequence_len
=
0
# Some consistency check.
if
inference_max_sequence_len
:
assert
self
.
inference_current_sequence_len
<
self
.
inference_key_memory
.
size
(
0
)
assert
inference_max_sequence_len
==
self
.
inference_key_memory
.
size
(
0
)
# This is added for safety. In case inference_max_sequence_len
# is not provided, make sure there is no potential memory left
# from previous inference.
if
not
inference_max_sequence_len
:
self
.
inference_key_memory
=
None
self
.
inference_value_memory
=
None
# =====================
# Query, Key, and Value
# =====================
if
self
.
attention_type
==
AttnType
.
self_attn
:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer
,
_
=
self
.
query_key_value
(
hidden_states
)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape
=
mixed_x_layer
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads_per_partition
,
3
*
self
.
hidden_size_per_attention_head
,
)
mixed_x_layer
=
mixed_x_layer
.
view
(
*
new_tensor_shape
)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(
query_layer
,
key_layer
,
value_layer
)
=
tensor_parallel
.
split_tensor_along_last_dim
(
mixed_x_layer
,
3
)
else
:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer
,
_
=
self
.
key_value
(
encoder_output
)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape
=
mixed_kv_layer
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads_per_partition
,
2
*
self
.
hidden_size_per_attention_head
,
)
mixed_kv_layer
=
mixed_kv_layer
.
view
(
*
new_tensor_shape
)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(
key_layer
,
value_layer
)
=
tensor_parallel
.
split_tensor_along_last_dim
(
mixed_kv_layer
,
2
)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer
,
_
=
self
.
query
(
hidden_states
)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape
=
query_layer
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
,
)
query_layer
=
query_layer
.
view
(
*
new_tensor_shape
)
# ===================================================
# Adjust key, value, and attention mask for inference
# ===================================================
if
inference_max_sequence_len
:
# Adjust the range variables.
start
=
self
.
inference_current_sequence_len
self
.
inference_current_sequence_len
+=
key_layer
.
size
(
0
)
end
=
self
.
inference_current_sequence_len
# Copy key and values.
self
.
inference_key_memory
[
start
:
end
,
...]
=
key_layer
self
.
inference_value_memory
[
start
:
end
,
...]
=
value_layer
key_layer
=
self
.
inference_key_memory
[:
end
,
...]
value_layer
=
self
.
inference_value_memory
[:
end
,
...]
# Adjust attention mask
attention_mask
=
attention_mask
[...,
start
:
end
,
:
end
]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size
=
(
query_layer
.
size
(
1
),
query_layer
.
size
(
2
),
query_layer
.
size
(
0
),
key_layer
.
size
(
0
))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer
=
query_layer
.
view
(
output_size
[
2
],
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer
=
key_layer
.
view
(
output_size
[
3
],
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# preallocting result tensor: [b * np, sq, sk]
matmul_result
=
torch
.
empty
(
output_size
[
0
]
*
output_size
[
1
],
output_size
[
2
],
output_size
[
3
],
dtype
=
query_layer
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result
=
torch
.
baddbmm
(
matmul_result
,
query_layer
.
transpose
(
0
,
1
),
# [b * np, sq, hn]
key_layer
.
transpose
(
0
,
1
).
transpose
(
1
,
2
),
# [b * np, hn, sk]
beta
=
0.0
,
alpha
=
(
1.0
/
self
.
norm_factor
),
)
# change view to [b, np, sq, sk]
attention_scores
=
matmul_result
.
view
(
*
output_size
)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs
=
self
.
scale_mask_softmax
(
attention_scores
,
attention_mask
)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with
tensor_parallel
.
get_cuda_rng_tracker
().
fork
():
attention_probs
=
self
.
attention_dropout
(
attention_probs
)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size
=
(
value_layer
.
size
(
1
),
value_layer
.
size
(
2
),
query_layer
.
size
(
0
),
value_layer
.
size
(
3
))
# change view [sk, b * np, hn]
value_layer
=
value_layer
.
view
(
value_layer
.
size
(
0
),
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# change view [b * np, sq, sk]
attention_probs
=
attention_probs
.
view
(
output_size
[
0
]
*
output_size
[
1
],
output_size
[
2
],
-
1
)
# matmul: [b * np, sq, hn]
context_layer
=
torch
.
bmm
(
attention_probs
,
value_layer
.
transpose
(
0
,
1
))
# change view [b, np, sq, hn]
context_layer
=
context_layer
.
view
(
*
output_size
)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer
=
context_layer
.
permute
(
2
,
0
,
1
,
3
).
contiguous
()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
hidden_size_per_partition
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
# =================
# Output. [sq, b, h]
# =================
output
,
bias
=
self
.
dense
(
context_layer
)
return
output
,
bias
@
torch
.
jit
.
script
def
bias_dropout_add
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
,
training
:
bool
)
->
torch
.
Tensor
:
out
=
torch
.
nn
.
functional
.
dropout
(
x
+
bias
,
p
=
prob
,
training
=
training
)
out
=
residual
+
out
return
out
def
get_bias_dropout_add
(
training
):
def
_bias_dropout_add
(
x
,
bias
,
residual
,
prob
):
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
training
)
return
_bias_dropout_add
@
torch
.
jit
.
script
def
bias_dropout_add_fused_train
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
True
)
@
torch
.
jit
.
script
def
bias_dropout_add_fused_inference
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
False
)
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_number,
        layer_type=LayerType.encoder,
        self_attn_mask_type=AttnMaskType.padding,
    ):
        args = get_args()

        super().__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type,
        )
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method, output_layer_init_method, layer_number, attention_type=AttnType.cross_attn
            )
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
    ):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = self.self_attention(
            layernorm_output,
            attention_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.Module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
            )

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = self.inter_attention(
                layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
            )
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
                )

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout
            )

        return output
class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_type=LayerType.encoder,
        self_attn_mask_type=AttnMaskType.padding,
        pre_process=True,
        post_process=True,
    ):
        super().__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        num_layers = args.num_layers
        # Number of layers.
        assert (
            num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
        ), "num_layers must be divisible by pipeline_model_parallel_size"
        self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
            )

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
                "num_layers_per_stage must be divisible by "
                "virtual_pipeline_model_parallel_size"
            )
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0] [2] [4] [6]
            # Stage 1: [1] [3] [5] [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1] [4, 5]
            # Stage 1: [2, 3] [6, 7]
            # (A small standalone sketch of this offset arithmetic follows the class.)
            offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size
            ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_

            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
            parallel ranks if `distribute-checkpointed-activations` is on
            and either of the following conditions is met:
              - it is not the first layer in the pipeline stage. The first
                layer is used by pipeline parallelism and changing its shape
                throws an error in the backward pass.
              - we are at the first pipeline stage, so the input tensor is
                not used in pipeline parallelism. Note that no pipeline
                parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = layer_number > 0
            is_first_pipeline_stage = parallel_state.get_pipeline_model_parallel_rank() == 0
            return self.distribute_checkpointed_activations and (
                not_first_layer_in_pipeline_stage or is_first_pipeline_stage
            )

        if self.activations_checkpoint_method == "uniform":
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing the number of checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    distribute_checkpointed_activations_helper(l),
                    hidden_states,
                    attention_mask,
                    encoder_output,
                    enc_dec_attn_mask,
                )
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == "block":
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to make full use of device memory by removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        hidden_states,
                        attention_mask,
                        encoder_output,
                        enc_dec_attn_mask,
                    )
                else:
                    hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
    ):
        # Checks.
        if inference_max_sequence_len:
            assert self.activations_checkpoint_method is None, "inference does not work with activation checkpointing"

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    set_inference_key_value_memory=set_inference_key_value_memory,
                    inference_max_sequence_len=inference_max_sequence_len,
                )

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output
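# --- Illustrative sketch (not part of the original file) ---------------------
# Pure-arithmetic illustration of the interleaved layer-to-stage assignment
# used in ParallelTransformer.__init__ when virtual pipeline parallelism is
# enabled. No parallel state is touched; the ranks are plain loop variables,
# and the helper name is hypothetical.
def _virtual_pipeline_offset_sketch(num_layers=8, pipeline_size=2, virtual_size=4):
    layers_per_stage = num_layers // pipeline_size          # layers owned by one stage
    layers_per_chunk = layers_per_stage // virtual_size     # layers in one model chunk
    assignment = {}
    for virtual_rank in range(virtual_size):
        for pipeline_rank in range(pipeline_size):
            offset = virtual_rank * (num_layers // virtual_size) + pipeline_rank * layers_per_chunk
            chunk = [offset + i for i in range(layers_per_chunk)]
            assignment.setdefault(pipeline_rank, []).append(chunk)
    # With the defaults this reproduces the comment in __init__:
    # stage 0 -> [[0], [2], [4], [6]], stage 1 -> [[1], [3], [5], [7]].
    # With virtual_size=2 it gives stage 0 -> [[0, 1], [4, 5]], stage 1 -> [[2, 3], [6, 7]].
    return assignment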
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(
    num_tokentypes,
    add_pooler,
    encoder_attn_mask_type,
    init_method=None,
    scaled_init_method=None,
    add_encoder=True,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
    )
    # key used for checkpoints.
    language_model_key = "language_model"

    return language_model, language_model_key
class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
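# --- Illustrative sketch (not part of the original file) ---------------------
# What the Pooler computes, written with a plain nn.Linear standing in for the
# get_linear_layer helper: pick one token's hidden state and apply tanh(Wx + b).
# The helper name and sizes are made up for illustration.
def _pooler_sketch():
    import torch

    b, s, h = 2, 6, 16
    hidden_states = torch.randn(b, s, h)
    dense = torch.nn.Linear(h, h)
    pooled = torch.tanh(dense(hidden_states[:, 0, :]))   # pool the first token
    assert pooled.shape == (b, h)
    return pooled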
class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
            is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0):
        super().__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size,
            self.hidden_size,
            init_method=self.init_method,
            use_cpu_initialization=args.use_cpu_initialization,
        )
        self._word_embeddings_key = "word_embeddings"

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
        self._position_embeddings_key = "position_embeddings"
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = "tokentype_embeddings"
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        print("FINISH WORD EMBEDDING", self.word_embeddings)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception("tokentype embeddings is already initialized")
        if torch.distributed.get_rank() == 0:
            print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
        """For easy load."""
        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                destination, prefix, keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if "word_embeddings" in key:
                    state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if "position_embeddings" in key:
                    state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if "tokentype_embeddings" in key:
                        state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print(
                    "***WARNING*** expected tokentype embeddings in the "
                    "checkpoint but could not find it",
                    flush=True,
                )
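# --- Illustrative sketch (not part of the original file) ---------------------
# The backward-compatibility branch of Embedding.load_state_dict strips the
# "word_embeddings." / "position_embeddings." prefixes from flat checkpoint
# keys before delegating to the sub-module. The same string surgery on a toy
# dict (helper name hypothetical):
def _checkpoint_key_remap_sketch():
    old_style = {
        "word_embeddings.weight": "W",
        "position_embeddings.weight": "P",
    }
    word_only = {
        key.split("word_embeddings.")[1]: value
        for key, value in old_style.items()
        if "word_embeddings" in key
    }
    # {"weight": "W"} -- ready for self.word_embeddings.load_state_dict(...)
    return word_only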
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
            is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        encoder_attn_mask_type,
        num_tokentypes=0,
        add_encoder=True,
        add_decoder=False,
        decoder_attn_mask_type=AttnMaskType.causal,
        add_pooler=False,
        pre_process=True,
        post_process=True,
    ):
        super().__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(
                self.hidden_size,
                args.padded_vocab_size,
                args.max_position_embeddings,
                args.hidden_dropout,
                self.init_method,
                self.num_tokentypes,
            )
            self._embedding_key = "embedding"

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._encoder_key = "encoder"
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            # Temporary assertion until we verify correctness of pipeline parallelism
            # implementation of T5.
            assert (
                args.pipeline_model_parallel_size == 1
            ), "pipeline parallelism is not supported in the presence of decoder"
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process,
            )
            self._decoder_key = "decoder"
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = "pooler"

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert (
                len(input_tensor) == 1
            ), "input_tensor should only be length 1 for stage with both encoder and decoder"
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, "input_tensor should only be length 1 for stage with only encoder"
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception("input_tensor must have either length 1 or 2")
        else:
            raise Exception("Stage must have at least either encoder or decoder")

    def forward(
        self,
        enc_input_ids,
        enc_position_ids,
        enc_attn_mask,
        dec_input_ids=None,
        dec_position_ids=None,
        dec_attn_mask=None,
        enc_dec_attn_mask=None,
        tokentype_ids=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        pooling_sequence_index=0,
        enc_hidden_states=None,
        output_enc_hidden=False,
    ):
        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    set_inference_key_value_memory=set_inference_key_value_memory,
                    inference_max_sequence_len=inference_max_sequence_len,
                )
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output, pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids, dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len,
        )

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
        """For easy load."""
        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )
        if self.add_encoder:
            state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars
                )
        if self.add_decoder:
            state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars
            )

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if "_embeddings" in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif "transformer" in state_dict:
                state_dict_ = state_dict["transformer"]
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if "transformer." in key:
                        state_dict_[key.split("transformer.")[1]] = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if ".attention." in key:
                    state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert "pooler" in state_dict, "could not find data for pooler in the checkpoint"
                self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
        # Decoder.
        if self.add_decoder:
            assert "decoder" in state_dict, "could not find data for decoder in the checkpoint"
            self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy):
    # Output.
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
        return loss
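# --- Illustrative sketch (not part of the original file) ---------------------
# Single-device analogue of the branch above: with no labels the raw logits are
# returned (inference); with labels a cross-entropy loss is computed, and the
# fp16 path simply skips the .float() upcast when the logits are already half.
# Plain torch.nn.functional.cross_entropy stands in for the tensor-parallel
# vocab_parallel_cross_entropy; names and sizes are made up for illustration.
def _post_processing_sketch(labels_given: bool = True):
    import torch
    import torch.nn.functional as F

    vocab, batch = 50, 4
    logits = torch.randn(batch, vocab)              # stand-in for parallel_lm_logits output
    if not labels_given:
        return logits                               # inference path: raw logits
    labels = torch.randint(0, vocab, (batch,))
    return F.cross_entropy(logits.float(), labels)  # training path: scalar loss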
class GPTModel(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True):
        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        self.initialize_word_embeddings(init_method_normal)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        labels=None,
        tokentype_ids=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
    ):
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len,
        )

        if self.post_process:
            return post_language_model_processing(
                lm_output, labels, self.word_embeddings_weight(), self.parallel_output, self.fp16_lm_cross_entropy
            )
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
        state_dict_ = {}
        state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
                destination, prefix, keep_vars
            )
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
def gpt_model_provider(pre_process=True, post_process=True):
    model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process)
    return model
apex/transformer/utils.py
0 → 100644
View file @
db92ee13
"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`"""
import
torch
from
apex.transformer
import
parallel_state
def
ensure_divisibility
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator."""
assert
numerator
%
denominator
==
0
,
"{} is not divisible by {}"
.
format
(
numerator
,
denominator
)
def
divide
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility
(
numerator
,
denominator
)
return
numerator
//
denominator
def
split_tensor_into_1d_equal_chunks
(
tensor
):
"""Break a tensor into equal 1D chunks."""
data
=
tensor
.
view
(
-
1
)
partition_size
=
torch
.
numel
(
data
)
//
parallel_state
.
get_tensor_model_parallel_world_size
()
start_index
=
partition_size
*
parallel_state
.
get_tensor_model_parallel_rank
()
end_index
=
start_index
+
partition_size
return
data
[
start_index
:
end_index
]
def
gather_split_1d_tensor
(
tensor
):
"""Opposite of above function, gather values from model parallel ranks."""
world_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
numel
=
torch
.
numel
(
tensor
)
numel_gathered
=
world_size
*
numel
gathered
=
torch
.
empty
(
numel_gathered
,
dtype
=
tensor
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
)
chunks
=
[
gathered
[
i
*
numel
:(
i
+
1
)
*
numel
]
for
i
in
range
(
world_size
)]
torch
.
distributed
.
all_gather
(
chunks
,
tensor
,
group
=
parallel_state
.
get_tensor_model_parallel_group
())
return
gathered
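# --- Illustrative sketch (not part of the original file) ---------------------
# The indexing arithmetic of split_tensor_into_1d_equal_chunks /
# gather_split_1d_tensor, emulated on CPU without a process group or CUDA:
# every "rank" takes an equal slice of the flattened tensor, and concatenating
# the slices in rank order restores the original data. Helper name and sizes
# are hypothetical.
def _split_gather_sketch(world_size: int = 4):
    import torch

    tensor = torch.arange(16, dtype=torch.float32)
    partition_size = torch.numel(tensor) // world_size
    chunks = [
        tensor.view(-1)[rank * partition_size:(rank + 1) * partition_size]
        for rank in range(world_size)
    ]
    gathered = torch.cat(chunks)                 # what all_gather reassembles
    assert torch.equal(gathered, tensor)
    return chunks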
csrc/amp_C_frontend.cpp
View file @
db92ee13
@@ -33,6 +33,12 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::optional<bool> per_tensor_python);
 
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::optional<bool> per_tensor_python);
+
 std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
   int chunk_size,
   at::Tensor noop_flag,
@@ -119,6 +125,25 @@ void multi_tensor_lamb_cuda(
   const float max_grad_norm,
   at::optional<bool> use_nvlamb_python);
 
+void multi_tensor_lamb_mp_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::Tensor lr,
+  const float beta1,
+  const float beta2,
+  const float epsilon,
+  at::Tensor step,
+  const int bias_correction,
+  const float weight_decay,
+  const int grad_averaging,
+  const int mode,
+  at::Tensor global_grad_norm,
+  at::Tensor max_grad_norm,
+  at::optional<bool> use_nvlamb_python,
+  at::Tensor found_inf,
+  at::Tensor inv_scale);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
@@ -128,6 +153,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "out = a*x + b*y for a list of contiguous tensors");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
+  m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda,
+        "Computes L2 norm for a list of contiguous tensors");
   m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
         "Computes L2 norm for a list of contiguous tensors and does scaling");
   m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
@@ -142,4 +169,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Compute and apply gradient update to parameters for Adam optimizer");
   m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
         "Computes and apply update for LAMB optimizer");
+  m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
+        "Computes and apply update for LAMB optimizer");
 }
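A rough sketch of how the newly exposed l2norm binding might be called from Python, assuming apex was built with the amp_C extension, a CUDA device is available, and the `_mp` variant follows the same calling convention as `multi_tensor_l2norm` (as its declared C++ signature above suggests); tensor sizes are arbitrary:

import torch

if torch.cuda.is_available():
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    params = [torch.randn(1024, device="cuda") for _ in range(4)]
    overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
    # per-tensor norms disabled; the binding returns (total_norm, per_tensor_norms)
    norm, _ = multi_tensor_applier(amp_C.multi_tensor_l2norm_mp, overflow_buf, [params], False)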
csrc/megatron/scaled_masked_softmax.cpp
View file @
db92ee13
 /* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
csrc/megatron/scaled_masked_softmax.h
View file @
db92ee13
 /* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
csrc/megatron/scaled_masked_softmax_cuda.cu
View file @
db92ee13
 /* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
csrc/megatron/scaled_upper_triang_masked_softmax.cpp
View file @
db92ee13
 /* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
csrc/megatron/scaled_upper_triang_masked_softmax.h
View file @
db92ee13
 /* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu
View file @
db92ee13
 /* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
csrc/multi_tensor_apply.cuh
View file @
db92ee13
@@ -3,7 +3,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <THC/THC.h>
 #include "compat.h"
 #include <assert.h>