OpenDAS / Megatron-LM · Commits

Commit 6fd0b406, authored Nov 25, 2021 by zihanl

    merge with main branch

Parents: 492fdf83, 60750922

Changes: 92. Showing 20 changed files with 977 additions and 344 deletions (+977 / -344).
Changed files shown below:

    megatron/fused_kernels/tests/__init__.py              +0    -0
    megatron/fused_kernels/tests/test_fused_kernels.py    +300  -0
    megatron/global_vars.py                                +13   -0
    megatron/initialize.py                                 +31   -19
    megatron/model/__init__.py                             +1    -0
    megatron/model/enums.py                                +4    -0
    megatron/model/fused_bias_gelu.py                      +0    -4
    megatron/model/fused_layer_norm.py                     +24   -3
    megatron/model/fused_softmax.py                        +73   -45
    megatron/model/gpt_model.py                            +3    -15
    megatron/model/language_model.py                       +122  -62
    megatron/model/module.py                               +33   -8
    megatron/model/t5_model.py                             +58   -26
    megatron/model/transformer.py                          +120  -98
    megatron/mpu/__init__.py                               +4    -2
    megatron/mpu/initialize.py                             +109  -4
    megatron/mpu/layers.py                                 +55   -9
    megatron/mpu/mappings.py                               +1    -1
    megatron/mpu/random.py                                 +23   -47
    megatron/optimizer/__init__.py                         +3    -1
megatron/fused_kernels/tests/__init__.py  (new file, mode 100644; empty)
megatron/fused_kernels/tests/test_fused_kernels.py  (new file, mode 100644)
import math

import torch
from torch.nn import LayerNorm

from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func


def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False, True, True, ..., True, True, True],
              [False, False, True, ..., True, True, True],
              [False, False, False, ..., True, True, True],
              ...,
              [False, False, False, ..., False, True, True],
              [False, False, False, ..., False, False, True],
              [False, False, False, ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
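Note: all three tests above share the same accuracy check: take the element-wise absolute difference between the fused and the reference output and reduce it to one mean value compared against 1e-3. A minimal standalone sketch of that helper (illustrative only, not part of the commit):

import torch

def mean_abs_difference(fused_out: torch.Tensor, reference_out: torch.Tensor) -> float:
    # Reduce the element-wise absolute difference to a single scalar,
    # mirroring the while-loop reduction used in the tests above.
    diff = (fused_out - reference_out).abs()
    while diff.dim() != 1:
        diff = diff.mean(dim=-1)
    return diff.mean(dim=-1).item()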
megatron/global_vars.py

@@ -21,6 +21,7 @@ import time
 import torch

+from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator

@@ -31,6 +32,7 @@ _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
+_GLOBAL_SIGNAL_HANDLER = None

 def get_args():

@@ -75,6 +77,14 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS

+def get_signal_handler():
+    _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    return _GLOBAL_SIGNAL_HANDLER
+
+
+def _set_signal_handler():
+    global _GLOBAL_SIGNAL_HANDLER
+    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
+

 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False):

@@ -89,6 +99,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_adlr_autoresume(args)
     _set_timers()

+    if args.exit_signal_handler:
+        _set_signal_handler()
+

 def _parse_args(extra_args_provider=None, defaults={},
                 ignore_unknown_args=False):
megatron/initialize.py

@@ -21,6 +21,7 @@ import time
 import numpy as np
 import torch

+from datetime import timedelta
+
 from megatron import fused_kernels
 from megatron import get_adlr_autoresume

@@ -63,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
             print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed)

+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization = True

@@ -77,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()

-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
         # Autoresume.
         _init_autoresume()

@@ -175,15 +176,11 @@ def _initialize_distributed():
         else:
             args.local_rank = device
         torch.cuda.set_device(device)

-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
+    # Call the init process
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        timeout=timedelta(minutes=10))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.

@@ -193,7 +190,8 @@ def _initialize_distributed():
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size)
+                                          args.virtual_pipeline_model_parallel_size,
+                                          args.pipeline_model_parallel_split_rank)


 def _init_autoresume():

@@ -229,10 +227,24 @@ def write_args_to_tensorboard():
                                 global_step=args.iteration)


-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)

-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
megatron/model/__init__.py

@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
+from .enums import ModelType
megatron/model/enums.py

@@ -15,6 +15,10 @@

 import enum

+class ModelType(enum.Enum):
+    encoder_or_decoder = 1
+    encoder_and_decoder = 2
+
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
megatron/model/fused_bias_gelu.py

@@ -15,10 +15,6 @@

 import torch

-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
megatron/model/fused_layer_norm.py

@@ -23,6 +23,12 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+    HAVE_PERSIST_LAYER_NORM = True
+except:
+    HAVE_PERSIST_LAYER_NORM = False
+
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None

@@ -61,13 +67,23 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):

 class MixedFusedLayerNorm(torch.nn.Module):

-  def __init__(self, normalized_shape, eps=1e-5):
+  def __init__(self, normalized_shape, eps=1e-5,
+               no_persist_layer_norm=True):
        super(MixedFusedLayerNorm, self).__init__()

        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
          "fused_mix_prec_layer_norm_cuda")

+       # List of hiddens sizes supported in the persistent layer norm kernel
+       # If the hidden size is not supported, fall back to the non-persistent
+       # kernel.
+       persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+           5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+           24576, 25600, 30720, 32768, 40960, 49152, 65536]
+       if normalized_shape not in persist_ln_hidden_sizes or \
+               not HAVE_PERSIST_LAYER_NORM:
+           no_persist_layer_norm = True
+
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)

@@ -75,6 +91,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()
+       self.no_persist_layer_norm = no_persist_layer_norm

   def reset_parameters(self):

@@ -85,6 +102,10 @@ class MixedFusedLayerNorm(torch.nn.Module):

   def forward(self, input):
-    return FusedLayerNormAffineFunction.apply(
-        input, self.weight, self.bias, self.normalized_shape, self.eps)
+    if self.no_persist_layer_norm:
+        return FusedLayerNormAffineFunction.apply(
+            input, self.weight, self.bias, self.normalized_shape, self.eps)
+    else:
+        return FastLayerNormFN.apply(input, self.weight, self.bias, self.eps)
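Note: the persistent path above only takes effect when Apex provides FastLayerNormFN and the hidden size is one of the sizes the persistent kernel supports; otherwise the layer falls back to FusedLayerNormAffineFunction. A minimal sketch of that selection logic (names are illustrative, not the module's API):

# Hidden sizes supported by the persistent layer norm kernel, as listed in the diff above.
PERSIST_LN_HIDDEN_SIZES = {
    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288,
    12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720, 32768, 40960,
    49152, 65536,
}

def use_persistent_layer_norm(hidden_size: int, have_persist_kernel: bool,
                              no_persist_layer_norm: bool = False) -> bool:
    # Fall back to the non-persistent fused kernel whenever the persistent one
    # is unavailable or the hidden size is unsupported.
    if not have_persist_kernel or hidden_size not in PERSIST_LN_HIDDEN_SIZES:
        return False
    return not no_persist_layer_norm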
megatron/model/fused_softmax.py

@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
+import torch.nn as nn
 from megatron.model.enums import AttnMaskType

@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
             inputs, scale_t[0]
         )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
             output_grads, softmax_results, scale_t[0]
         )
         return input_grads, None

@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         scale_t = torch.tensor([scale])

-        softmax_results = scaled_masked_softmax_cuda.forward(
-            inputs, mask, scale_t[0]
-        )
+        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None


-class FusedScaleMaskSoftmax(torch.nn.Module):
+class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax
+
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
         input_in_bf16: flag to indicate if input in bf16 data format.
         attn_mask_type: attention mask type (pad or causal)
         scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

     def __init__(

@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (self.input_in_fp16 and self.input_in_bf16), \
-            'both fp16 and bf16 flags cannot be active at the same time.'
+        assert not (
+            self.input_in_fp16 and self.input_in_bf16
+        ), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
-        data_size = input.size()
-        query_seq_len = data_size[-2]
-        key_seq_len = data_size[-1]
-        attn_batch_size = data_size[0] * data_size[1]
-
-        # constraints on various tensor dimensions to enable warp based
-        # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-
-        # invoke custom kernel
-        if self.input_in_float16 and mask is not None and \
-            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-            scale = self.scale if self.scale is not None else 1.0
-            if self.attn_mask_type == AttnMaskType.causal:
-                assert query_seq_len == key_seq_len, \
-                    "causal mask is only for self attention"
-                input = input.view(-1, query_seq_len, key_seq_len)
-                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-                probs = probs.view(*data_size)
-            else:
-                assert self.attn_mask_type == AttnMaskType.padding
-                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
         else:
-            if self.input_in_float16 and self.softmax_in_fp32:
-                input = input.float()
-
-            if self.scale is not None:
-                input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask is not None else input
-            probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-            if self.input_in_float16 and self.softmax_in_fp32:
-                if self.input_in_fp16:
-                    probs = probs.half()
-                else:
-                    probs = probs.bfloat16()
-
-        return probs
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user want to fuse
+            and self.input_in_float16  # input must be fp16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisor of 4
+            and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        b, np, sq, sk = input.size()
+        scale = self.scale if self.scale is not None else 1.0
+
+        if self.attn_mask_type == AttnMaskType.causal:
+            assert sq == sk, "causal mask is only for self attention"
+
+            # input is 3D tensor (attn_batches, sq, sk)
+            input = input.view(-1, sq, sk)
+            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+            return probs.view(b, np, sq, sk)
+        else:
+            # input is 4D tensor (b, np, sq, sk)
+            return ScaledMaskedSoftmax.apply(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+
+        return probs

     @staticmethod
     def get_batch_per_block(sq, sk, b, np):
         import scaled_masked_softmax_cuda

         return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
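Note: the refactored forward() delegates the choice between the fused CUDA path and the plain torch.nn.Softmax path to is_kernel_available(), which encodes the shape constraints of the fused kernels. A standalone sketch of that predicate (batch_per_block is normally queried from the compiled scaled_masked_softmax_cuda extension; here it is passed in, so this function is illustrative only):

def fused_softmax_kernel_available(fusion_enabled: bool, input_in_float16: bool,
                                   have_mask: bool, b: int, np: int, sq: int, sk: int,
                                   causal: bool, batch_per_block: int) -> bool:
    attn_batches = b * np
    if not (fusion_enabled                 # user wants to fuse
            and input_in_float16           # input must be fp16/bf16
            and have_mask                  # mask tensor must not be None
            and 16 < sk <= 2048            # key sequence length must be in (16, 2048]
            and sq % 4 == 0                # query sequence length must be divisible by 4
            and attn_batches % 4 == 0):    # b * np must be divisible by 4
        return False
    if causal:
        return attn_batches % batch_per_block == 0
    return sq % batch_per_block == 0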
megatron/model/gpt_model.py

@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal

 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output

     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)

-    if get_key_value:
-        output = [output, presents]
-
     if labels is None:
         return output
     else:

@@ -90,23 +82,19 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
megatron/model/language_model.py

@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_decoder=False,
+                       scaled_init_method=None, add_encoder=True,
+                       add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""

@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
         scaled_init_method,
         encoder_attn_mask_type,
         num_tokentypes=num_tokentypes,
+        add_encoder=add_encoder,
         add_decoder=add_decoder,
         decoder_attn_mask_type=decoder_attn_mask_type,
         add_pooler=add_pooler,

@@ -159,6 +161,16 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

+    def zero_parameters(self):
+        """Zero out all parameters in embedding."""
+        self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
+        self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True
+
     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.

@@ -273,6 +285,7 @@ class TransformerLanguageModel(MegatronModule):
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  num_tokentypes=0,
+                 add_encoder=True,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False,

@@ -286,10 +299,12 @@ class TransformerLanguageModel(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.add_encoder = add_encoder
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
+        self.encoder_hidden_state = None

         # Embeddings.
         if self.pre_process:

@@ -302,25 +317,37 @@ class TransformerLanguageModel(MegatronModule):
             self._embedding_key = 'embedding'

-        # Transformer.
-        self.encoder = ParallelTransformer(
-            self.init_method,
-            output_layer_init_method,
-            self_attn_mask_type=self.encoder_attn_mask_type,
-            pre_process=self.pre_process,
-            post_process=self.post_process
-        )
-        self._encoder_key = 'encoder'
+        # Encoder (usually set to True, False if part of an encoder-decoder
+        # architecture and in encoder-only stage).
+        if self.add_encoder:
+            self.encoder = ParallelTransformer(
+                self.init_method,
+                output_layer_init_method,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process
+            )
+            self._encoder_key = 'encoder'
+        else:
+            self.encoder = None

-        # Decoder
+        # Decoder (usually set to False, True if part of an encoder-decoder
+        # architecture and in decoder-only stage).
         if self.add_decoder:
-            # Temporary assertion until we verify correctness of pipeline parallelism
-            # implementation of T5.
-            assert args.pipeline_model_parallel_size == 1, \
-                'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=self.decoder_attn_mask_type)
+                self_attn_mask_type=self.decoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process)
             self._decoder_key = 'decoder'
+        else:
+            self.decoder = None

         if self.post_process:
             # Pooler.

@@ -330,28 +357,55 @@ class TransformerLanguageModel(MegatronModule):

     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
-        self.encoder.set_input_tensor(input_tensor)
+
+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+
+        if self.add_encoder and self.add_decoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with both encoder and decoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_encoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with only encoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_decoder:
+            if len(input_tensor) == 2:
+                self.decoder.set_input_tensor(input_tensor[0])
+                self.encoder_hidden_state = input_tensor[1]
+            elif len(input_tensor) == 1:
+                self.decoder.set_input_tensor(None)
+                self.encoder_hidden_state = input_tensor[0]
+            else:
+                raise Exception('input_tensor must have either length 1 or 2')
+        else:
+            raise Exception('Stage must have at least either encoder or decoder')

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                inference_params=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

-        # Embeddings.
+        # Encoder embedding.
         if self.pre_process:
-            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
-                                              tokentype_ids=tokentype_ids)
-            encoder_input = embedding_output
+            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
+                                           tokentype_ids=tokentype_ids)
         else:
             encoder_input = None

-        # encoder.
+        # Run encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            if self.encoder is not None:
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    inference_params=inference_params)
+            else:
+                encoder_output = self.encoder_hidden_state
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -369,16 +423,20 @@ class TransformerLanguageModel(MegatronModule):
             else:
                 return encoder_output

-        # Decoder Embedding
-        dec_embedding_output = self.embedding(dec_input_ids,
-                                              dec_position_ids)
-        # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        # Decoder embedding.
+        if self.pre_process:
+            decoder_input = self.embedding(dec_input_ids,
+                                           dec_position_ids)
+        else:
+            decoder_input = None
+
+        # Run decoder.
+        decoder_output = self.decoder(
+            decoder_input,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            inference_params=inference_params)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output

@@ -394,9 +452,10 @@ class TransformerLanguageModel(MegatronModule):
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        state_dict_[self._encoder_key] \
-            = self.encoder.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.add_encoder:
+            state_dict_[self._encoder_key] \
+                = self.encoder.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
         if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \

@@ -425,38 +484,39 @@ class TransformerLanguageModel(MegatronModule):
             self.embedding.load_state_dict(state_dict_, strict=strict)

         # Encoder.
-        if self._encoder_key in state_dict:
-            state_dict_ = state_dict[self._encoder_key]
-        # for backward compatibility.
-        elif 'transformer' in state_dict:
-            state_dict_ = state_dict['transformer']
-        else:
-            # for backward compatibility.
-            state_dict_ = {}
-            for key in state_dict.keys():
-                if 'transformer.' in key:
-                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
-
-        # for backward compatibility.
-        state_dict_self_attention = {}
-        for key in state_dict_.keys():
-            if '.attention.' in key:
-                state_dict_self_attention[key.replace(".attention.",
-                    ".self_attention.")] = state_dict_[key]
-            else:
-                state_dict_self_attention[key] = state_dict_[key]
-        state_dict_ = state_dict_self_attention
-
-        self.encoder.load_state_dict(state_dict_, strict=strict)
+        if self.add_encoder:
+            if self._encoder_key in state_dict:
+                state_dict_ = state_dict[self._encoder_key]
+            # For backward compatibility.
+            elif 'transformer' in state_dict:
+                state_dict_ = state_dict['transformer']
+            else:
+                # For backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'transformer.' in key:
+                        state_dict_[key.split('transformer.')[1]] = state_dict[key]
+
+            # For backward compatibility.
+            state_dict_self_attention = {}
+            for key in state_dict_.keys():
+                if '.attention.' in key:
+                    state_dict_self_attention[key.replace(".attention.",
+                        ".self_attention.")] = state_dict_[key]
+                else:
+                    state_dict_self_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_self_attention
+
+            self.encoder.load_state_dict(state_dict_, strict=strict)

+        # Pooler.
         if self.post_process:
-            # pooler
             if self.add_pooler:
                 assert 'pooler' in state_dict, \
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)
-        # decoder
+        # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
                 'could not find data for pooler in the checkpoint'
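Note: the new set_input_tensor() routes the incoming activation(s) differently depending on whether the stage holds the encoder, the decoder, or both. A minimal stand-in for that dispatch, with the transformer objects replaced by a returned dictionary (illustrative only, not the module's API):

def route_input_tensor(input_tensor, add_encoder: bool, add_decoder: bool) -> dict:
    # schedules.py normally passes a list; some inference code still passes a
    # bare tensor or None, so normalize first (as the diff above does).
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]

    if add_encoder:
        # Encoder-only stage, or a stage holding both encoder and decoder.
        assert len(input_tensor) == 1, 'expected a single input tensor'
        return {'encoder_input': input_tensor[0]}
    if add_decoder:
        if len(input_tensor) == 2:
            return {'decoder_input': input_tensor[0],
                    'encoder_hidden_state': input_tensor[1]}
        if len(input_tensor) == 1:
            return {'decoder_input': None,
                    'encoder_hidden_state': input_tensor[0]}
        raise Exception('input_tensor must have either length 1 or 2')
    raise Exception('Stage must have at least either encoder or decoder')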
megatron/model/module.py

@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):

     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
+                mpu.get_pipeline_model_parallel_world_size() == 1:
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage(ignore_virtual=True):
+        else:
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
             return self.word_embeddings.weight
-        raise Exception('word_embeddings_weight() should be '
-                        'called for first and last stage only')

     def initialize_word_embeddings(self, init_method_normal):

@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
                             'share_word_embeddings is false')
         # This function just initializes the word embeddings in the final stage
-        # when we are using pipeline parallelism. If we aren't using pipeline
-        # parallelism there is nothing to do.
+        # when we are using pipeline parallelism. Nothing to do if we aren't
+        # using pipeline parallelism.
         if args.pipeline_model_parallel_size == 1:
             return

-        # Parameters are shared between the word embeddings layer, and the
+        # Parameters are shared between the word embeddings layers, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
         # workers, so we do the following:

@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True

+        # Zero out initial weights for decoder embedding.
+        # NOTE: We don't currently support T5 with the interleaved schedule.
+        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
+                not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
+                mpu.is_rank_in_embedding_group():
+            self.language_model.embedding.zero_parameters()
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if torch.distributed.is_initialized():
-            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
+            if mpu.is_rank_in_embedding_group():
                 torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                              group=mpu.get_embedding_group())
+
+                # All-reduce other embeddings as well as necessary. The last stage
+                # does not have these other embeddings, so just create placeholder
+                # tensors of the right shape with all zeros.
+                # NOTE: We don't currently support T5 with the interleaved schedule.
+                if args.pipeline_model_parallel_split_rank is not None:
+                    # TODO: Support tokentype embedding.
+                    dimensions = (args.max_position_embeddings, args.hidden_size)
+                    if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                        position_embeddings = torch.nn.Embedding(*dimensions).cuda()
+                        position_embeddings.weight.data.fill_(0)
+                    else:
+                        self.language_model.embedding.cuda()
+                        position_embeddings = self.language_model.embedding.position_embeddings
+                    torch.distributed.all_reduce(position_embeddings.weight.data,
+                                                 group=mpu.get_embedding_group())
         else:
             print("WARNING! Distributed processes aren't initialized, so "
                   "word embeddings in the last layer are not initialized. "

@@ -166,6 +187,10 @@ class Float16Module(MegatronModule):
         self.float16_convertor = float16_convertor

+    def set_input_tensor(self, input_tensor):
+        return self.module.set_input_tensor(input_tensor)
+
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
             inputs = fp32_to_float16(inputs, self.float16_convertor)
megatron/model/t5_model.py

@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):

 class T5Model(MegatronModule):
     """T5 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True,
+                 add_encoder=True,
+                 add_decoder=True):
         super(T5Model, self).__init__()
         args = get_args()

@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.add_encoder = add_encoder
+        self.add_decoder = add_decoder

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            add_decoder=True,
+            add_encoder=add_encoder,
+            add_decoder=add_decoder,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

-        self.lm_head = T5LMHead(
-            self.language_model.embedding.word_embeddings.weight.size(0),
-            parallel_output)
-        self._lm_head_key = 'lm_head'
         self.initialize_word_embeddings(init_method_normal)
+        if self.post_process and self.add_decoder:
+            self.lm_head = T5LMHead(
+                self.word_embeddings_weight().size(0),
+                parallel_output)
+            self._lm_head_key = 'lm_head'

     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""

@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
             tokentype_ids=tokentype_ids,
             enc_hidden_states=enc_hidden_states)

-        decoder_output, encoder_output = lm_output
-
-        # Output.
-        lm_logits = self.lm_head(decoder_output,
-                                 self.language_model.embedding.word_embeddings.weight)
+        if self.post_process and self.add_decoder:
+            decoder_output, encoder_output = lm_output
+            # Output.
+            lm_logits = self.lm_head(decoder_output,
+                                     self.word_embeddings_weight())

-        if lm_labels is None:
-            return lm_logits, encoder_output
-        else:
-            if self.fp16_lm_cross_entropy:
-                assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
-            else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
-            return lm_loss, encoder_output
+            if lm_labels is None:
+                return lm_logits
+            else:
+                if self.fp16_lm_cross_entropy:
+                    assert lm_logits.dtype == torch.half
+                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                else:
+                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                               lm_labels)
+                return lm_loss
+        elif self.add_decoder and not self.add_encoder:
+            decoder_output, encoder_output = lm_output
+            return decoder_output
+        else:
+            encoder_output = lm_output
+            return encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._lm_head_key] \
-            = self.lm_head.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.post_process and self.add_decoder:
+            state_dict_[self._lm_head_key] \
+                = self.lm_head.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+        # Save word_embeddings.
+        if self.post_process and not self.pre_process and self.add_decoder:
+            state_dict_[self._word_embeddings_for_head_key] \
+                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):

@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        self.lm_head.load_state_dict(state_dict[self._lm_head_key],
-                                     strict=strict)
+        if self.post_process and self.add_decoder:
+            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
+                                         strict=strict)
+        # Load word embeddings.
+        if self.post_process and not self.pre_process and self.add_decoder:
+            self.word_embeddings.load_state_dict(
+                state_dict[self._word_embeddings_for_head_key], strict=strict)
megatron/model/transformer.py
View file @
6fd0b406
...
...
@@ -21,17 +21,12 @@ import torch.nn.functional as F
from
megatron
import
get_args
from
megatron
import
mpu
from
.module
import
MegatronModule
from
megatron.model.enums
import
AttnMaskType
,
LayerType
,
AttnType
from
megatron.model.enums
import
AttnMaskType
,
ModelType
,
LayerType
,
AttnType
from
megatron.model
import
LayerNorm
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.fused_bias_gelu
import
bias_gelu_impl
from
megatron.model.utils
import
attention_mask_func
,
openai_gelu
,
erf_gelu
# flags required to enable jit fusion kernels
torch
.
_C
.
_jit_set_profiling_mode
(
False
)
torch
.
_C
.
_jit_set_profiling_executor
(
False
)
torch
.
_C
.
_jit_override_can_fuse_on_cpu
(
True
)
torch
.
_C
.
_jit_override_can_fuse_on_gpu
(
True
)
""" We use the following notation throughout this file:
h: hidden size
...
...
@@ -53,8 +48,7 @@ class ParallelMLP(MegatronModule):
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
state back into h hidden dimension.
"""
def
__init__
(
self
,
init_method
,
output_layer_init_method
):
...
...
@@ -84,7 +78,6 @@ class ParallelMLP(MegatronModule):
init_method
=
output_layer_init_method
,
skip_bias_add
=
True
)
def
forward
(
self
,
hidden_states
):
# [s, b, 4hp]
...
...
@@ -125,6 +118,7 @@ class ParallelAttention(MegatronModule):
self
.
layer_number
=
max
(
1
,
layer_number
)
self
.
attention_type
=
attention_type
self
.
attn_mask_type
=
attn_mask_type
self
.
params_dtype
=
args
.
params_dtype
projection_size
=
args
.
kv_channels
*
args
.
num_attention_heads
...
...
@@ -185,10 +179,40 @@ class ParallelAttention(MegatronModule):
init_method
=
output_layer_init_method
,
skip_bias_add
=
True
)
def
forward
(
self
,
hidden_states
,
attention_mask
,
layer_past
=
None
,
get_key_value
=
False
,
encoder_output
=
None
):
def
_allocate_memory
(
self
,
inference_max_sequence_len
,
batch_size
):
return
torch
.
empty
(
inference_max_sequence_len
,
batch_size
,
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
,
dtype
=
self
.
params_dtype
,
device
=
torch
.
cuda
.
current_device
())
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
inference_params
=
None
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if
inference_params
:
if
self
.
layer_number
not
in
inference_params
.
key_value_memory_dict
:
inf_max_seq_len
=
inference_params
.
max_sequence_len
inf_max_batch_size
=
inference_params
.
max_batch_size
inference_key_memory
=
self
.
_allocate_memory
(
inf_max_seq_len
,
inf_max_batch_size
)
inference_value_memory
=
self
.
_allocate_memory
(
inf_max_seq_len
,
inf_max_batch_size
)
inference_params
.
key_value_memory_dict
[
self
.
layer_number
]
=
(
inference_key_memory
,
inference_value_memory
)
else
:
inference_key_memory
,
inference_value_memory
=
\
inference_params
.
key_value_memory_dict
[
self
.
layer_number
]
# =====================
# Query, Key, and Value
# =====================
...
...
@@ -229,18 +253,28 @@ class ParallelAttention(MegatronModule):
self
.
hidden_size_per_attention_head
)
query_layer
=
query_layer
.
view
(
*
new_tensor_shape
)
# ==================================
# Adjust key and value for inference
# ==================================
if
layer_past
is
not
None
:
past_key
,
past_value
=
layer_past
key_layer
=
torch
.
cat
((
past_key
.
type_as
(
key_layer
),
key_layer
),
dim
=
0
)
value_layer
=
torch
.
cat
((
past_value
.
type_as
(
value_layer
),
value_layer
),
dim
=
0
)
if
get_key_value
:
present
=
(
key_layer
,
value_layer
)
if
inference_params
:
batch_start
=
inference_params
.
batch_size_offset
batch_end
=
batch_start
+
key_layer
.
size
(
1
)
assert
batch_end
<=
inference_key_memory
.
size
(
1
)
sequence_start
=
inference_params
.
sequence_len_offset
sequence_end
=
sequence_start
+
key_layer
.
size
(
0
)
assert
sequence_end
<=
inference_key_memory
.
size
(
0
)
# Copy key and values.
inference_key_memory
[
sequence_start
:
sequence_end
,
batch_start
:
batch_end
,
...]
=
key_layer
inference_value_memory
[
sequence_start
:
sequence_end
,
batch_start
:
batch_end
,
...]
=
value_layer
key_layer
=
inference_key_memory
[
:
sequence_end
,
batch_start
:
batch_end
,
...]
value_layer
=
inference_value_memory
[
:
sequence_end
,
batch_start
:
batch_end
,
...]
# ===================================
# Raw attention scores. [b, np, s, s]
...
...
@@ -277,22 +311,6 @@ class ParallelAttention(MegatronModule):
# change view to [b, np, sq, sk]
attention_scores
=
matmul_result
.
view
(
*
output_size
)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if
get_key_value
:
with
torch
.
no_grad
():
if
layer_past
is
not
None
:
attention_mask
=
attention_mask
[
...,
attention_scores
.
size
(
3
)
-
1
,
:
attention_scores
.
size
(
3
)].
unsqueeze
(
2
)
else
:
attention_mask
=
attention_mask
[
...,
:
attention_scores
.
size
(
3
),
:
attention_scores
.
size
(
3
)]
# ===========================
# Attention probs and dropout
...
...
@@ -348,9 +366,6 @@ class ParallelAttention(MegatronModule):
output
,
bias
=
self
.
dense
(
context_layer
)
if
get_key_value
:
output
=
[
output
,
present
]
return
output
,
bias
...
...
@@ -368,14 +383,18 @@ def get_bias_dropout_add(training):
@
torch
.
jit
.
script
def
bias_dropout_add_fused_train
(
x
,
bias
,
residual
,
prob
):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
def
bias_dropout_add_fused_train
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
True
)
@
torch
.
jit
.
script
def
bias_dropout_add_fused_inference
(
x
,
bias
,
residual
,
prob
):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
def
bias_dropout_add_fused_inference
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
False
)
...
...
@@ -404,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the input data.
self
.
input_layernorm
=
LayerNorm
(
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
)
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
)
# Self attention.
self
.
self_attention
=
ParallelAttention
(
...
...
@@ -419,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the attention output
self
.
post_attention_layernorm
=
LayerNorm
(
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
)
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
)
if
self
.
layer_type
==
LayerType
.
decoder
:
self
.
inter_attention
=
ParallelAttention
(
...
...
@@ -430,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the attention output.
self
.
post_inter_attention_layernorm
=
LayerNorm
(
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
)
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
)
# MLP
self
.
mlp
=
ParallelMLP
(
init_method
,
...
...
@@ -438,20 +460,17 @@ class ParallelTransformerLayer(MegatronModule):
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
enc_dec_attn_mask
=
None
,
layer_past
=
None
,
get_key_value
=
Fals
e
):
inference_params
=
Non
e
):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output
=
self
.
input_layernorm
(
hidden_states
)
# Self attention.
attention_output
,
attention_bias
=
\
self
.
self_attention
(
layernorm_output
,
attention_mask
,
layer_past
=
layer_past
,
get_key_value
=
get_key_value
)
if
get_key_value
:
attention_output
,
presents
=
attention_output
self
.
self_attention
(
layernorm_output
,
attention_mask
,
inference_params
=
inference_params
)
# Residual connection.
if
self
.
apply_residual_connection_post_layernorm
:
...
...
@@ -521,9 +540,6 @@ class ParallelTransformerLayer(MegatronModule):
residual
,
self
.
hidden_dropout
)
if
get_key_value
:
output
=
[
output
,
presents
]
return
output
...
...
@@ -544,13 +560,13 @@ class ParallelTransformer(MegatronModule):
self
.
input_tensor
=
None
# Store activation checkpoiting flag.
self
.
checkpoint_activations
=
args
.
checkpoint_activations
self
.
checkpoint_num_layers
=
args
.
checkpoint_num_layers
self
.
activations_checkpoint_method
=
args
.
activations_checkpoint_method
self
.
activations_checkpoint_num_layers
=
args
.
activations_checkpoint_num_layers
self
.
distribute_checkpointed_activations
=
args
.
distribute_checkpointed_activations
# Number of layers.
assert
args
.
num_layers
%
mpu
.
get_pipeline_model_parallel_world_size
()
==
0
,
\
'num_layers must be divisible by pipeline_model_parallel_size'
self
.
num_layers
=
args
.
num_layers
//
mpu
.
get_pipeline_model_parallel_world_size
()
self
.
num_layers
=
mpu
.
get_num_layers
(
args
,
args
.
model_type
==
ModelType
.
encoder_and_decoder
)
# Transformer layers.
def
build_layer
(
layer_number
):
...
...
@@ -589,7 +605,8 @@ class ParallelTransformer(MegatronModule):
# Final layer norm before output.
self
.
final_layernorm
=
LayerNorm
(
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
)
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
)
def
_get_layer
(
self
,
layer_number
):
return
self
.
layers
[
layer_number
]
...
...
@@ -609,14 +626,32 @@ class ParallelTransformer(MegatronModule):
return
x_
return
custom_forward
# Make sure memory is freed.
mpu
.
reset_checkpointed_activations_memory_buffer
()
l
=
0
while
l
<
self
.
num_layers
:
hidden_states
=
mpu
.
checkpoint
(
custom
(
l
,
l
+
self
.
checkpoint_num_layers
),
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
)
l
+=
self
.
checkpoint_num_layers
if
self
.
activations_checkpoint_method
==
'uniform'
:
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage reducing checkpoints.
l
=
0
while
l
<
self
.
num_layers
:
hidden_states
=
mpu
.
checkpoint
(
custom
(
l
,
l
+
self
.
activations_checkpoint_num_layers
),
self
.
distribute_checkpointed_activations
,
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
)
l
+=
self
.
activations_checkpoint_num_layers
elif
self
.
activations_checkpoint_method
==
'block'
:
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method fully use the device memory removing redundant re-computation.
for
l
in
range
(
self
.
num_layers
):
if
l
<
self
.
activations_checkpoint_num_layers
:
hidden_states
=
mpu
.
checkpoint
(
custom
(
l
,
l
+
1
),
self
.
distribute_checkpointed_activations
,
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
)
else
:
hidden_states
=
custom
(
l
,
l
+
1
)(
hidden_states
,
attention_mask
,
encoder_output
,
enc_dec_attn_mask
)
else
:
raise
ValueError
(
"Invalid activation checkpoint method."
)
return
hidden_states
...
...
@@ -630,18 +665,14 @@ class ParallelTransformer(MegatronModule):
forward_step_func"""
self
.
input_tensor
=
input_tensor
def
forward
(
self
,
hidden_states
,
attention_mask
,
layer_past
=
None
,
get_key_value
=
False
,
encoder_output
=
None
,
enc_dec_attn_mask
=
None
):
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
enc_dec_attn_mask
=
None
,
inference_params
=
None
):
# Checks.
if
layer_past
is
not
None
:
assert
get_key_value
,
\
'for not None values in layer_past, '
\
'expected get_key_value to be set'
if
get_key_value
:
assert
not
self
.
checkpoint_activations
,
\
'get_key_value does not work with '
\
'activation checkpointing'
if
inference_params
:
assert
self
.
activations_checkpoint_method
is
None
,
\
'inference does not work with activation checkpointing'
if
self
.
pre_process
:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
...
...
@@ -658,28 +689,21 @@ class ParallelTransformer(MegatronModule):
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
-            if get_key_value:
-                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(hidden_states,
+                                      attention_mask,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask,
+                                      inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
...
...
@@ -688,7 +712,5 @@ class ParallelTransformer(MegatronModule):
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
-        if get_key_value:
-            output = [output, presents]

        return output
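Editor's note: the old layer_past / get_key_value pair is folded into a single inference_params object that every layer can read and update. The real InferenceParams class is not shown in this diff; the toy version below (ToyInferenceParams and toy_layer are invented names) only illustrates the pattern of threading one mutable cache through the layer loop.

import torch

class ToyInferenceParams:
    # Stand-in for Megatron's inference-time state object (hypothetical).
    def __init__(self):
        self.key_value_memory = {}   # layer index -> cached tensor

def toy_layer(index, hidden_states, inference_params=None):
    if inference_params is not None:
        # Stand-in for appending this step's keys/values to the cache.
        cache = inference_params.key_value_memory.setdefault(
            index, torch.zeros_like(hidden_states))
        cache += hidden_states
    return hidden_states

hidden_states = torch.randn(4, 2, 8)   # [s, b, h]
params = ToyInferenceParams()
for index in range(3):
    hidden_states = toy_layer(index, hidden_states, inference_params=params)
assert len(params.key_value_memory) == 3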
megatron/mpu/__init__.py
...
...
@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
+from .initialize import is_rank_in_embedding_group
+from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
+from .initialize import is_pipeline_stage_at_split
+from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
...
...
@@ -56,9 +60,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
...
...
megatron/mpu/initialize.py
...
...
@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...
...
@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

+# A list of ranks that have a copy of the embedding.
+_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
-# rank when broadcasting from the first or last pipeline stage
+# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None


def is_unitialized():
...
...
@@ -52,13 +56,19 @@ def is_unitialized():
-def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1,
-                              virtual_pipeline_model_parallel_size_=None):
+def initialize_model_parallel(tensor_model_parallel_size_=1,
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None,
+                              pipeline_model_parallel_split_rank_=None):
    """
    Initialize model data parallel groups.

    Arguments:
-        tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
-        pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
+        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
+        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
+        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
+                                              pipeline).
+        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
+                                            rank in pipeline with split point.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...
...
@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+    if pipeline_model_parallel_split_rank_ is not None:
+        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
...
...
@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
+    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
...
...
@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
+            if pipeline_model_parallel_split_rank_ is not None and \
+                    pipeline_model_parallel_split_rank_ not in embedding_ranks:
+                embedding_ranks = [ranks[0],
+                                   ranks[pipeline_model_parallel_split_rank_],
+                                   ranks[-1]]
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
+        if rank in ranks:
+            _EMBEDDING_GLOBAL_RANKS = embedding_ranks


def model_parallel_is_initialized():
...
...
@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


+def get_num_layers(args, is_encoder_and_decoder_model):
+    """Compute the number of transformer layers resident on the current rank."""
+    if get_pipeline_model_parallel_world_size() > 1:
+        if is_encoder_and_decoder_model:
+            assert args.pipeline_model_parallel_split_rank is not None
+            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
+            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+            assert args.num_layers % num_ranks_in_encoder == 0, \
+                'num_layers must be divisible by number of ranks given to encoder'
+            assert args.num_layers % num_ranks_in_decoder == 0, \
+                'num_layers must be divisible by number of ranks given to decoder'
+            if is_pipeline_stage_before_split():
+                num_layers = args.num_layers // num_ranks_in_encoder
+            else:
+                num_layers = args.num_layers // num_ranks_in_decoder
+        else:
+            assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
+                'num_layers must be divisible by pipeline_model_parallel_size'
+            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
+    else:
+        num_layers = args.num_layers
+    return num_layers


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
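Editor's note: stripped of the args object and the process-group queries, the layer-split arithmetic in get_num_layers reduces to the self-contained helper below; layers_on_this_rank is a hypothetical name used only for this sketch.

def layers_on_this_rank(num_layers, pipeline_world_size, rank, split_rank=None):
    """How many transformer layers the given pipeline rank holds."""
    if pipeline_world_size == 1:
        return num_layers
    if split_rank is None:
        assert num_layers % pipeline_world_size == 0
        return num_layers // pipeline_world_size
    # Encoder-and-decoder model: ranks before the split hold encoder layers,
    # ranks at or after the split hold decoder layers.
    num_ranks_in_encoder = split_rank
    num_ranks_in_decoder = pipeline_world_size - num_ranks_in_encoder
    assert num_layers % num_ranks_in_encoder == 0
    assert num_layers % num_ranks_in_decoder == 0
    if rank < split_rank:
        return num_layers // num_ranks_in_encoder
    return num_layers // num_ranks_in_decoder

# 12 layers per model half on a 4-stage pipeline split 2/2 between encoder and decoder:
assert layers_on_this_rank(12, 4, rank=1, split_rank=2) == 6   # encoder stage
assert layers_on_this_rank(12, 4, rank=3, split_rank=2) == 6   # decoder stage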
...
...
@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
            get_pipeline_model_parallel_world_size() - 1)


+def is_rank_in_embedding_group(ignore_virtual=False):
+    """Return true if current rank is in embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _EMBEDDING_GLOBAL_RANKS
+    if ignore_virtual:
+        return rank in _EMBEDDING_GLOBAL_RANKS
+    if rank in _EMBEDDING_GLOBAL_RANKS:
+        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
+            return is_pipeline_first_stage(ignore_virtual=False)
+        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+            return is_pipeline_last_stage(ignore_virtual=False)
+        else:
+            return True
+    return False
+
+
+def is_pipeline_stage_before_split(rank=None):
+    """Return True if pipeline stage executes encoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_after_split(rank=None):
+    """Return True if pipeline stage executes decoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_at_split():
+    """Return true if this pipeline stage executes the encoder block and the
+    next stage executes the decoder block, for a model with both encoder and
+    decoder."""
+    rank = get_pipeline_model_parallel_rank()
+    return is_pipeline_stage_before_split(rank) and \
+        is_pipeline_stage_after_split(rank + 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
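Editor's note: because the split predicates above only compare the pipeline rank against the stored split rank, they can be restated on plain integers. The helpers below (stage_is_* are hypothetical names, not part of the mpu API) make the boundary behaviour at the encoder/decoder split explicit.

def stage_is_before_split(rank, world_size, split_rank):
    # Encoder side of the pipeline.
    if world_size == 1 or split_rank is None:
        return True
    return rank < split_rank

def stage_is_after_split(rank, world_size, split_rank):
    # Decoder side of the pipeline.
    if world_size == 1 or split_rank is None:
        return True
    return rank >= split_rank

def stage_is_at_split(rank, world_size, split_rank):
    # Last encoder stage: this stage is before the split, the next is after it.
    return (stage_is_before_split(rank, world_size, split_rank) and
            stage_is_after_split(rank + 1, world_size, split_rank))

assert stage_is_at_split(1, world_size=4, split_rank=2)       # last encoder stage
assert not stage_is_at_split(2, world_size=4, split_rank=2)   # first decoder stage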
...
...
@@ -356,9 +457,13 @@ def get_data_parallel_rank():
def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
megatron/mpu/layers.py
...
...
@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
...
...
@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output


+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+    """
+    Column-parallel linear layer execution with asynchronous all-reduce
+    execution in backprop.
+    """
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        output = torch.matmul(input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        grad_input = grad_output.matmul(weight)
+        # Asynchronous all-reduce.
+        handle = torch.distributed.all_reduce(
+            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+        # Delay the start of weight gradient computation shortly (3us) to have
+        # the all-reduce scheduled first and GPU resources allocated.
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_weight = grad_output.t().matmul(input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        handle.wait()
+        return grad_input, grad_weight, grad_bias
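Editor's note: the point of the Function above is ordering: launch the all-reduce of grad_input asynchronously, compute grad_weight and grad_bias while it is in flight, then wait before returning. The single-process sketch below keeps that ordering but fakes the handle, since with one process there is nothing to reduce; LinearWithOverlappedGradReduce is an illustrative name, not part of the mpu API.

import torch

class LinearWithOverlappedGradReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())
        return output + bias if bias is not None else output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        # In the real kernel this is torch.distributed.all_reduce(grad_input,
        # group=..., async_op=True); here a dummy handle stands in for it.
        class _FakeHandle:
            def wait(self):
                pass
        handle = _FakeHandle()
        grad_weight = grad_output.t().matmul(input)   # overlaps with the reduce
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        handle.wait()                                 # make sure grad_input is final
        return grad_input, grad_weight, grad_bias

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
b = torch.randn(16, requires_grad=True)
LinearWithOverlappedGradReduce.apply(x, w, b).sum().backward()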
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.
...
...
@@ -256,7 +288,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                    device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
...
...
@@ -272,21 +304,35 @@ class ColumnParallelLinear(torch.nn.Module):
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

+        self.async_tensor_model_parallel_allreduce = (
+            not args.no_async_tensor_model_parallel_allreduce and
+            world_size > 1)

    def forward(self, input_):
-        # Set up backprop all-reduce.
-        input_parallel = copy_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        if self.async_tensor_model_parallel_allreduce:
+            input_shape = input_.shape
+            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
+            # Matrix multiply with asynchronous all-reduce execution.
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
+                input_, self.weight, bias)
+            output_parallel = output_parallel.view(
+                input_shape[0], input_shape[1], output_parallel.shape[1])
+        else:
+            # Set up backprop all-reduce.
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+            # Matrix multiply.
+            output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
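Editor's note: the async path flattens the leading [sequence, batch] dimensions into one before the matmul and restores them afterwards, since the custom Function operates on 2-D inputs. The shape bookkeeping on its own looks like the following (sizes are arbitrary examples, not Megatron defaults).

import torch

x = torch.randn(32, 4, 1024)                  # [s, b, h]
weight = torch.randn(4096, 1024)              # [output_per_partition, h]
s, b, h = x.shape
y = torch.matmul(x.view(s * b, h), weight.t())
y = y.view(s, b, weight.shape[0])             # back to [s, b, output_per_partition]
assert y.shape == (32, 4, 4096)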
...
...
@@ -316,8 +362,8 @@ class RowParallelLinear(torch.nn.Module):
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
-        skip_bias_add: This was added to enable performance optimations where bias
-                       can be fused with other elementwise operations. we skip
+        skip_bias_add: This was added to enable performance optimization where bias
+                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
    """
...
...
megatron/mpu/mappings.py
...
...
@@ -20,7 +20,7 @@ from .utils import split_tensor_along_last_dim
def _reduce(input_):
-    """All-reduce the the input tensor across model parallel group."""
+    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
...
...
megatron/mpu/random.py
...
...
@@ -24,7 +24,6 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron import get_args
-from megatron.memory import allocate_mem_buff

from .initialize import get_data_parallel_rank
...
...
@@ -37,37 +36,6 @@ from .initialize import get_tensor_model_parallel_world_size
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

-# Whether to apply model parallelism to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initialize the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.
...
...
@@ -101,14 +69,21 @@ def _set_cuda_rng_state(new_state, device=-1):
    _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
    start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
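Editor's note: the new new_buffer flag decides whether the returned chunk aliases the input tensor or is copied into fresh storage. A CPU-only restatement with the world size and rank passed in as plain integers (split_into_1d_chunks is a hypothetical helper) makes the difference visible.

import torch

def split_into_1d_chunks(tensor, world_size, rank, new_buffer=False):
    partition_size = tensor.numel() // world_size
    start = partition_size * rank
    end = start + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype,
                           requires_grad=False)
        data.copy_(tensor.view(-1)[start:end])   # private copy
    else:
        data = tensor.view(-1)[start:end]        # view into the original storage
    return data

t = torch.arange(12.0).reshape(3, 4)
view_chunk = split_into_1d_chunks(t, world_size=4, rank=1)
copy_chunk = split_into_1d_chunks(t, world_size=4, rank=1, new_buffer=True)
t.view(-1)[3] = -1.0
assert view_chunk[0].item() == -1.0    # the view saw the in-place update
assert copy_chunk[0].item() == 3.0     # the private copy did not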
def gather_split_1d_tensor(tensor):
    """Opposite of above function, gather values from model parallel ranks."""
...
...
@@ -250,8 +225,10 @@ class CheckpointFunction(torch.autograd.Function):
    tracked/set/reset.
    """
    @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
        ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
...
...
@@ -263,16 +240,14 @@ class CheckpointFunction(torch.autograd.Function):
        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
            ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(
+                args[0].data, new_buffer=True)

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
...
...
@@ -281,7 +256,7 @@ class CheckpointFunction(torch.autograd.Function):
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
...
...
@@ -310,10 +285,11 @@ class CheckpointFunction(torch.autograd.Function):
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
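Editor's note: the signature change threads a non-tensor flag through a torch.autograd.Function, which is why backward now returns (None, None) + grads: one gradient slot per forward argument, with None for run_function and None for the new flag. The minimal, unrelated example below (ScaleIfEnabled is an invented name) shows that pattern in isolation.

import torch

class ScaleIfEnabled(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_scaling, x):
        # Stash the non-tensor flag on ctx so backward can see it.
        ctx.run_scaling = run_scaling
        return x * 2.0 if run_scaling else x

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = grad_output * 2.0 if ctx.run_scaling else grad_output
        # One gradient per forward argument: None for the flag, grad for x.
        return None, grad_x

x = torch.randn(3, requires_grad=True)
ScaleIfEnabled.apply(True, x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))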
megatron/optimizer/__init__.py
...
...
@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
                                               args.clip_grad,
                                               args.log_num_zeros_in_grad,
                                               params_have_main_grad,
+                                              args.use_contiguous_buffers_in_local_ddp,
                                               args.bf16,
                                               grad_scaler)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,
                         args.log_num_zeros_in_grad,
-                        params_have_main_grad)
+                        params_have_main_grad,
+                        args.use_contiguous_buffers_in_local_ddp)