OpenDAS / Megatron-LM

Commit 6fd0b406, authored Nov 25, 2021 by zihanl
Commit message: merge with main branch
Parents: 492fdf83, 60750922
Changes: 92
Showing 20 changed files with 977 additions and 344 deletions (+977 / -344)
Changed files shown (20):

megatron/fused_kernels/tests/__init__.py            +0    -0
megatron/fused_kernels/tests/test_fused_kernels.py  +300  -0
megatron/global_vars.py                             +13   -0
megatron/initialize.py                              +31   -19
megatron/model/__init__.py                          +1    -0
megatron/model/enums.py                             +4    -0
megatron/model/fused_bias_gelu.py                   +0    -4
megatron/model/fused_layer_norm.py                  +24   -3
megatron/model/fused_softmax.py                     +73   -45
megatron/model/gpt_model.py                         +3    -15
megatron/model/language_model.py                    +122  -62
megatron/model/module.py                            +33   -8
megatron/model/t5_model.py                          +58   -26
megatron/model/transformer.py                       +120  -98
megatron/mpu/__init__.py                            +4    -2
megatron/mpu/initialize.py                          +109  -4
megatron/mpu/layers.py                              +55   -9
megatron/mpu/mappings.py                            +1    -1
megatron/mpu/random.py                              +23   -47
megatron/optimizer/__init__.py                      +3    -1
megatron/fused_kernels/tests/__init__.py (new file, mode 0 → 100644, empty)
megatron/fused_kernels/tests/test_fused_kernels.py (new file, mode 0 → 100644)
import math

import torch
from torch.nn import LayerNorm

from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func


def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
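All three comparison tests above share the same acceptance check: reduce the element-wise absolute difference between the fused output and the reference output to a single mean and require it to stay at or below 1e-3. A minimal, self-contained sketch of that check (plain PyTorch on CPU, with hypothetical stand-in tensors rather than real kernel outputs):

import torch

def mean_abs_difference(fused_output: torch.Tensor, reference_output: torch.Tensor) -> torch.Tensor:
    # Element-wise absolute difference, reduced over the trailing dimension
    # repeatedly until one dimension remains, then averaged to a scalar,
    # mirroring the while-loop used in the tests above.
    result = (fused_output - reference_output).abs()
    while result.dim() != 1:
        result = result.mean(dim=-1)
    return result.mean(dim=-1)

# Hypothetical stand-ins for fused vs. reference outputs.
a = torch.randn(4, 12, 32, 32)
b = a + 1e-4 * torch.randn_like(a)
print("pass" if mean_abs_difference(a, b) <= 1e-3 else "fail")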
megatron/global_vars.py

@@ -21,6 +21,7 @@ import time
 import torch

+from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator

@@ -31,6 +32,7 @@ _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
+_GLOBAL_SIGNAL_HANDLER = None

 def get_args():

@@ -75,6 +77,14 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS

+def get_signal_handler():
+    _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    return _GLOBAL_SIGNAL_HANDLER
+
+def _set_signal_handler():
+    global _GLOBAL_SIGNAL_HANDLER
+    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()

 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False):

@@ -89,6 +99,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_adlr_autoresume(args)
     _set_timers()

+    if args.exit_signal_handler:
+        _set_signal_handler()

 def _parse_args(extra_args_provider=None, defaults={},
                 ignore_unknown_args=False):
megatron/initialize.py

@@ -21,6 +21,7 @@ import time
 import numpy as np
 import torch
+from datetime import timedelta

 from megatron import fused_kernels
 from megatron import get_adlr_autoresume

@@ -63,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed)

+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization = True

@@ -77,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()

-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
         # Autoresume.
         _init_autoresume()

@@ -175,15 +176,11 @@ def _initialize_distributed():
         else:
             args.local_rank = device
         torch.cuda.set_device(device)
     # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        timeout=timedelta(minutes=10))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.

@@ -193,7 +190,8 @@ def _initialize_distributed():
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size)
+                                          args.virtual_pipeline_model_parallel_size,
+                                          args.pipeline_model_parallel_split_rank)

 def _init_autoresume():

@@ -229,10 +227,24 @@ def write_args_to_tensorboard():
                                 global_step=args.iteration)

-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
-
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
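The new `_set_jit_fusion_options` keys the fuser choice off the installed PyTorch version: nvfuser for torch 1.10 and later, the legacy fuser otherwise. A minimal sketch of just that version check, printing the decision instead of flipping the private `torch._C` switches (assumes only that `torch` is importable):

import torch

def chosen_fuser() -> str:
    # Same major/minor parse as in _set_jit_fusion_options above.
    major, minor = (int(x) for x in torch.__version__.split('.')[:2])
    return "nvfuser" if (major > 1) or (major == 1 and minor >= 10) else "legacy pytorch fuser"

print(f"torch {torch.__version__}: would enable the {chosen_fuser()}")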
megatron/model/__init__.py

@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
+from .enums import ModelType
megatron/model/enums.py

@@ -15,6 +15,10 @@
 import enum

+class ModelType(enum.Enum):
+    encoder_or_decoder = 1
+    encoder_and_decoder = 2
+
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
megatron/model/fused_bias_gelu.py

@@ -15,10 +15,6 @@
 import torch

-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
megatron/model/fused_layer_norm.py

@@ -23,6 +23,12 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+    HAVE_PERSIST_LAYER_NORM = True
+except:
+    HAVE_PERSIST_LAYER_NORM = False
+
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None

@@ -61,13 +67,23 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):

-    def __init__(self, normalized_shape, eps=1e-5):
+    def __init__(self, normalized_shape, eps=1e-5,
+                 no_persist_layer_norm=True):
         super(MixedFusedLayerNorm, self).__init__()

         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")

+        # List of hiddens sizes supported in the persistent layer norm kernel
+        # If the hidden size is not supported, fall back to the non-persistent
+        # kernel.
+        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+            24576, 25600, 30720, 32768, 40960, 49152, 65536]
+        if normalized_shape not in persist_ln_hidden_sizes or \
+                not HAVE_PERSIST_LAYER_NORM:
+            no_persist_layer_norm = True
+
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)

@@ -75,6 +91,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.weight = Parameter(torch.Tensor(*normalized_shape))
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
+        self.no_persist_layer_norm = no_persist_layer_norm

     def reset_parameters(self):

@@ -85,6 +102,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def forward(self, input):

-        return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape, self.eps)
+        if self.no_persist_layer_norm:
+            return FusedLayerNormAffineFunction.apply(
+                input, self.weight, self.bias, self.normalized_shape, self.eps)
+        else:
+            return FastLayerNormFN.apply(
+                input, self.weight, self.bias, self.eps)
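The hunks above route `MixedFusedLayerNorm.forward` to Apex's persistent `FastLayerNormFN` kernel only when Apex ships that kernel and the hidden size is one of the sizes it supports; otherwise the module keeps the existing `FusedLayerNormAffineFunction` path. A small, self-contained sketch of that selection rule (plain Python; `have_persist_layer_norm` stands in for the `HAVE_PERSIST_LAYER_NORM` import probe):

PERSIST_LN_HIDDEN_SIZES = {
    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240,
    12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720, 32768,
    40960, 49152, 65536,
}

def use_persistent_layer_norm(hidden_size: int,
                              have_persist_layer_norm: bool,
                              no_persist_layer_norm: bool = True) -> bool:
    # Mirror of the constructor logic: force the non-persistent kernel when
    # the size is unsupported or the Apex kernel is unavailable.
    if hidden_size not in PERSIST_LN_HIDDEN_SIZES or not have_persist_layer_norm:
        no_persist_layer_norm = True
    return not no_persist_layer_norm

print(use_persistent_layer_norm(4096, True, no_persist_layer_norm=False))   # True
print(use_persistent_layer_norm(4097, True, no_persist_layer_norm=False))   # False: unsupported size
print(use_persistent_layer_norm(4096, False, no_persist_layer_norm=False))  # False: kernel unavailable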
megatron/model/fused_softmax.py

@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
+import torch.nn as nn

 from megatron.model.enums import AttnMaskType

@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
             inputs, scale_t[0]
         )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
             output_grads, softmax_results, scale_t[0]
         )
         return input_grads, None

@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         scale_t = torch.tensor([scale])
         softmax_results = scaled_masked_softmax_cuda.forward(
             inputs, mask, scale_t[0]
         )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None

-class FusedScaleMaskSoftmax(torch.nn.Module):
+class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax

     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
+        input_in_bf16: flag to indicate if input in bf16 data format.
         attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

     def __init__(

@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (self.input_in_fp16 and self.input_in_bf16), \
-            'both fp16 and bf16 flags cannot be active at the same time.'
+        assert not (
+            self.input_in_fp16 and self.input_in_bf16
+        ), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
-        data_size = input.size()
-        query_seq_len = data_size[-2]
-        key_seq_len = data_size[-1]
-        attn_batch_size = data_size[0] * data_size[1]
-
-        # constraints on various tensor dimensions to enable warp based
-        # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-
-        # invoke custom kernel
-        if self.input_in_float16 and mask is not None and \
-            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-            scale = self.scale if self.scale is not None else 1.0
-            if self.attn_mask_type == AttnMaskType.causal:
-                assert query_seq_len == key_seq_len, \
-                    "causal mask is only for self attention"
-                input = input.view(-1, query_seq_len, key_seq_len)
-                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-                probs = probs.view(*data_size)
-            else:
-                assert self.attn_mask_type == AttnMaskType.padding
-                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
-        else:
-            if self.input_in_float16 and self.softmax_in_fp32:
-                input = input.float()
-
-            if self.scale is not None:
-                input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask is not None else input
-            probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-            if self.input_in_float16 and self.softmax_in_fp32:
-                if self.input_in_fp16:
-                    probs = probs.half()
-                else:
-                    probs = probs.bfloat16()
-
-        return probs
+
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
+        else:
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user want to fuse
+            and self.input_in_float16  # input must be fp16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisor of 4
+            and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        b, np, sq, sk = input.size()
+        scale = self.scale if self.scale is not None else 1.0
+
+        if self.attn_mask_type == AttnMaskType.causal:
+            assert sq == sk, "causal mask is only for self attention"
+
+            # input is 3D tensor (attn_batches, sq, sk)
+            input = input.view(-1, sq, sk)
+            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+            return probs.view(b, np, sq, sk)
+        else:
+            # input is 4D tensor (b, np, sq, sk)
+            return ScaledMaskedSoftmax.apply(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+
+        return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        import scaled_masked_softmax_cuda
+
+        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
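The new `is_kernel_available` check gates the fused CUDA path on a handful of shape constraints before falling back to the plain PyTorch softmax. A self-contained sketch of those constraints as a pure-Python predicate; `batch_per_block` is taken as a parameter here (with a hypothetical value in the example) because the real value comes from `scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)`:

def fused_softmax_kernel_eligible(b: int, np: int, sq: int, sk: int,
                                  batch_per_block: int,
                                  fusion_enabled: bool = True,
                                  input_in_float16: bool = True,
                                  have_mask: bool = True,
                                  causal: bool = False) -> bool:
    # Constraints from is_kernel_available: fp16/bf16 input, a mask present,
    # 16 < sk <= 2048, sq and b*np divisible by 4, then a per-block
    # divisibility check that differs for causal vs. padding masks.
    attn_batches = b * np
    if not (fusion_enabled and input_in_float16 and have_mask
            and 16 < sk <= 2048 and sq % 4 == 0 and attn_batches % 4 == 0):
        return False
    if causal:
        return attn_batches % batch_per_block == 0
    return sq % batch_per_block == 0

# Example: attention scores of shape [b=4, np=12, sq=1024, sk=1024] with a
# hypothetical batch_per_block of 4.
print(fused_softmax_kernel_eligible(4, 12, 1024, 1024, batch_per_block=4, causal=True))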
megatron/model/gpt_model.py

@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output

     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)
-    if get_key_value:
-        output = [output, presents]

     if labels is None:
         return output
     else:

@@ -90,23 +82,19 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
megatron/model/language_model.py

@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_decoder=False,
+                       scaled_init_method=None, add_encoder=True,
+                       add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""

@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
         scaled_init_method,
         encoder_attn_mask_type,
         num_tokentypes=num_tokentypes,
+        add_encoder=add_encoder,
         add_decoder=add_decoder,
         decoder_attn_mask_type=decoder_attn_mask_type,
         add_pooler=add_pooler,

@@ -159,6 +161,16 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

+    def zero_parameters(self):
+        """Zero out all parameters in embedding."""
+        self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
+        self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True
+
     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.

@@ -273,6 +285,7 @@ class TransformerLanguageModel(MegatronModule):
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  num_tokentypes=0,
+                 add_encoder=True,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False,

@@ -286,10 +299,12 @@ class TransformerLanguageModel(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.add_encoder = add_encoder
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
+        self.encoder_hidden_state = None

         # Embeddings.
         if self.pre_process:

@@ -302,25 +317,37 @@ class TransformerLanguageModel(MegatronModule):
             self._embedding_key = 'embedding'

-        # Transformer.
-        self.encoder = ParallelTransformer(
-            self.init_method,
-            output_layer_init_method,
-            self_attn_mask_type=self.encoder_attn_mask_type,
-            pre_process=self.pre_process,
-            post_process=self.post_process
-        )
-        self._encoder_key = 'encoder'
+        # Encoder (usually set to True, False if part of an encoder-decoder
+        # architecture and in encoder-only stage).
+        if self.add_encoder:
+            self.encoder = ParallelTransformer(
+                self.init_method,
+                output_layer_init_method,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process
+            )
+            self._encoder_key = 'encoder'
+        else:
+            self.encoder = None

-        # Decoder
+        # Decoder (usually set to False, True if part of an encoder-decoder
+        # architecture and in decoder-only stage).
         if self.add_decoder:
+            # Temporary assertion until we verify correctness of pipeline parallelism
+            # implementation of T5.
             assert args.pipeline_model_parallel_size == 1, \
                 'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=self.decoder_attn_mask_type)
+                self_attn_mask_type=self.decoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process)
             self._decoder_key = 'decoder'
+        else:
+            self.decoder = None

         if self.post_process:
             # Pooler.

@@ -330,28 +357,55 @@ class TransformerLanguageModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
-        self.encoder.set_input_tensor(input_tensor)
+
+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+
+        if self.add_encoder and self.add_decoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with both encoder and decoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_encoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with only encoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_decoder:
+            if len(input_tensor) == 2:
+                self.decoder.set_input_tensor(input_tensor[0])
+                self.encoder_hidden_state = input_tensor[1]
+            elif len(input_tensor) == 1:
+                self.decoder.set_input_tensor(None)
+                self.encoder_hidden_state = input_tensor[0]
+            else:
+                raise Exception('input_tensor must have either length 1 or 2')
+        else:
+            raise Exception('Stage must have at least either encoder or decoder')

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                inference_params=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

-        # Embeddings.
+        # Encoder embedding.
         if self.pre_process:
-            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
-                                              tokentype_ids=tokentype_ids)
-            encoder_input = embedding_output
+            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                            tokentype_ids=tokentype_ids)
         else:
             encoder_input = None

-        # encoder.
+        # Run encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            if self.encoder is not None:
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    inference_params=inference_params)
+            else:
+                encoder_output = self.encoder_hidden_state
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -369,16 +423,20 @@ class TransformerLanguageModel(MegatronModule):
         else:
             return encoder_output

-        # Decoder Embedding
-        dec_embedding_output = self.embedding(dec_input_ids,
-                                              dec_position_ids)
-        # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        # Decoder embedding.
+        if self.pre_process:
+            decoder_input = self.embedding(dec_input_ids,
+                                           dec_position_ids)
+        else:
+            decoder_input = None
+
+        # Run decoder.
+        decoder_output = self.decoder(
+            decoder_input,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            inference_params=inference_params)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output

@@ -394,9 +452,10 @@ class TransformerLanguageModel(MegatronModule):
         state_dict_[self._embedding_key] \
             = self.embedding.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._encoder_key] \
-            = self.encoder.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.add_encoder:
+            state_dict_[self._encoder_key] \
+                = self.encoder.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
         if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \

@@ -425,38 +484,39 @@ class TransformerLanguageModel(MegatronModule):
         self.embedding.load_state_dict(state_dict_, strict=strict)

         # Encoder.
+        if self.add_encoder:
             if self._encoder_key in state_dict:
                 state_dict_ = state_dict[self._encoder_key]
             # For backward compatibility.
             elif 'transformer' in state_dict:
                 state_dict_ = state_dict['transformer']
             else:
                 # For backward compatibility.
                 state_dict_ = {}
                 for key in state_dict.keys():
                     if 'transformer.' in key:
                         state_dict_[key.split('transformer.')[1]] = state_dict[key]

             # For backward compatibility.
             state_dict_self_attention = {}
             for key in state_dict_.keys():
                 if '.attention.' in key:
                     state_dict_self_attention[key.replace(".attention.",
                         ".self_attention.")] = state_dict_[key]
                 else:
                     state_dict_self_attention[key] = state_dict_[key]
             state_dict_ = state_dict_self_attention

             self.encoder.load_state_dict(state_dict_, strict=strict)

+        # Pooler.
         if self.post_process:
             if self.add_pooler:
                 assert 'pooler' in state_dict, \
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)
-        # decoder
+        # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
                 'could not find data for pooler in the checkpoint'
megatron/model/module.py

@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
+                mpu.get_pipeline_model_parallel_world_size() == 1:
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage(ignore_virtual=True):
+        else:
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
             return self.word_embeddings.weight
-        raise Exception('word_embeddings_weight() should be '
-                        'called for first and last stage only')

     def initialize_word_embeddings(self, init_method_normal):

@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
                             'share_word_embeddings is false')

         # This function just initializes the word embeddings in the final stage
-        # when we are using pipeline parallelism. If we aren't using pipeline
-        # parallelism there is nothing to do.
+        # when we are using pipeline parallelism. Nothing to do if we aren't
+        # using pipeline parallelism.
         if args.pipeline_model_parallel_size == 1:
             return

-        # Parameters are shared between the word embeddings layer, and the
+        # Parameters are shared between the word embeddings layers, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
         # workers, so we do the following:

@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True

+        # Zero out initial weights for decoder embedding.
+        # NOTE: We don't currently support T5 with the interleaved schedule.
+        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
+                not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
+                mpu.is_rank_in_embedding_group():
+            self.language_model.embedding.zero_parameters()
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if torch.distributed.is_initialized():
-            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
+            if mpu.is_rank_in_embedding_group():
                 torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                              group=mpu.get_embedding_group())
+
+                # All-reduce other embeddings as well as necessary. The last stage
+                # does not have these other embeddings, so just create placeholder
+                # tensors of the right shape with all zeros.
+                # NOTE: We don't currently support T5 with the interleaved schedule.
+                if args.pipeline_model_parallel_split_rank is not None:
+                    # TODO: Support tokentype embedding.
+                    dimensions = (args.max_position_embeddings, args.hidden_size)
+                    if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                        position_embeddings = torch.nn.Embedding(*dimensions).cuda()
+                        position_embeddings.weight.data.fill_(0)
+                    else:
+                        self.language_model.embedding.cuda()
+                        position_embeddings = self.language_model.embedding.position_embeddings
+                    torch.distributed.all_reduce(position_embeddings.weight.data,
+                                                 group=mpu.get_embedding_group())
         else:
             print("WARNING! Distributed processes aren't initialized, so "
                   "word embeddings in the last layer are not initialized. "

@@ -166,6 +187,10 @@ class Float16Module(MegatronModule):
         self.float16_convertor = float16_convertor

+    def set_input_tensor(self, input_tensor):
+        return self.module.set_input_tensor(input_tensor)
+
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
             inputs = fp32_to_float16(inputs, self.float16_convertor)
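The comments in `initialize_word_embeddings` describe the tying protocol for pipeline parallelism: the stage that holds the output head creates its own word-embedding copy filled with zeros, and an all-reduce over the embedding group adds the first stage's initialized weights into it, so both copies start from identical values. A toy, single-process illustration of why the zero-fill plus sum-all-reduce (modelled here as a plain tensor sum) yields tied weights:

import torch

torch.manual_seed(0)

# Stage 0 owns the "real" initialization; the other stage starts from zeros,
# exactly like word_embeddings.weight.data.fill_(0) in the diff above.
first_stage_embedding = torch.randn(8, 4)
last_stage_embedding = torch.zeros(8, 4)

# all_reduce(SUM) across the embedding group is equivalent to summing the
# participants' tensors and writing the result back to every rank.
reduced = first_stage_embedding + last_stage_embedding
first_stage_embedding, last_stage_embedding = reduced.clone(), reduced.clone()

print(torch.equal(first_stage_embedding, last_stage_embedding))  # True: both stages now start tied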
megatron/model/t5_model.py

@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):
 class T5Model(MegatronModule):
     """T5 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True,
+                 add_encoder=True,
+                 add_decoder=True):
         super(T5Model, self).__init__()
         args = get_args()

@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.add_encoder = add_encoder
+        self.add_decoder = add_decoder

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            add_decoder=True,
+            add_encoder=add_encoder,
+            add_decoder=add_decoder,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

-        self.lm_head = T5LMHead(
-            self.language_model.embedding.word_embeddings.weight.size(0),
-            parallel_output)
-        self._lm_head_key = 'lm_head'
+        self.initialize_word_embeddings(init_method_normal)
+
+        if self.post_process and self.add_decoder:
+            self.lm_head = T5LMHead(
+                self.word_embeddings_weight().size(0),
+                parallel_output)
+            self._lm_head_key = 'lm_head'

     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""

@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
             tokentype_ids=tokentype_ids,
             enc_hidden_states=enc_hidden_states)

-        decoder_output, encoder_output = lm_output
-
-        # Output.
-        lm_logits = self.lm_head(decoder_output,
-                                 self.language_model.embedding.word_embeddings.weight)
-
-        if lm_labels is None:
-            return lm_logits, encoder_output
-        else:
-            if self.fp16_lm_cross_entropy:
-                assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
-            else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
-            return lm_loss, encoder_output
+        if self.post_process and self.add_decoder:
+            decoder_output, encoder_output = lm_output
+            # Output.
+            lm_logits = self.lm_head(decoder_output,
+                                     self.word_embeddings_weight())
+
+            if lm_labels is None:
+                return lm_logits
+            else:
+                if self.fp16_lm_cross_entropy:
+                    assert lm_logits.dtype == torch.half
+                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                else:
+                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                               lm_labels)
+                return lm_loss
+        elif self.add_decoder and not self.add_encoder:
+            decoder_output, encoder_output = lm_output
+            return decoder_output
+        else:
+            encoder_output = lm_output
+            return encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._lm_head_key] \
-            = self.lm_head.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.post_process and self.add_decoder:
+            state_dict_[self._lm_head_key] \
+                = self.lm_head.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+        # Save word_embeddings.
+        if self.post_process and not self.pre_process and self.add_decoder:
+            state_dict_[self._word_embeddings_for_head_key] \
+                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):

@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        self.lm_head.load_state_dict(state_dict[self._lm_head_key],
-                                     strict=strict)
+        if self.post_process and self.add_decoder:
+            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
+                                         strict=strict)
+        # Load word embeddings.
+        if self.post_process and not self.pre_process and self.add_decoder:
+            self.word_embeddings.load_state_dict(
+                state_dict[self._word_embeddings_for_head_key], strict=strict)
megatron/model/transformer.py
View file @
6fd0b406
...
@@ -21,17 +21,12 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
-from megatron.model.enums import AttnMaskType, LayerType, AttnType
+from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
-
 """ We use the following notation throughout this file:
     h: hidden size
...
@@ -53,8 +48,7 @@ class ParallelMLP(MegatronModule):
     MLP will take the input with h hidden state, project it to 4*h
     hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension. At the end, dropout is also
-    applied.
+    state back into h hidden dimension.
     """

     def __init__(self, init_method, output_layer_init_method):
...
@@ -84,7 +78,6 @@ class ParallelMLP(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

     def forward(self, hidden_states):
         # [s, b, 4hp]
...
@@ -125,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype

         projection_size = args.kv_channels * args.num_attention_heads
...
@@ -185,10 +179,40 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]

+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if inference_params:
+            if self.layer_number not in inference_params.key_value_memory_dict:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_batch_size = inference_params.max_batch_size
+                inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
+
         # =====================
         # Query, Key, and Value
         # =====================
...
@@ -229,18 +253,28 @@ class ParallelAttention(MegatronModule):
                              self.hidden_size_per_attention_head)
         query_layer = query_layer.view(*new_tensor_shape)

         # ==================================
         # Adjust key and value for inference
         # ==================================
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        if inference_params:
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= inference_key_memory.size(1)
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
+            # Copy key and values.
+            inference_key_memory[sequence_start:sequence_end,
+                                 batch_start:batch_end, ...] = key_layer
+            inference_value_memory[sequence_start:sequence_end,
+                                   batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[
+                :sequence_end, batch_start:batch_end, ...]
+            value_layer = inference_value_memory[
+                :sequence_end, batch_start:batch_end, ...]
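The pre-allocated key/value buffers above are written in place at a batch and sequence offset and then sliced up to the current position. A standalone sketch of that indexing with made-up sizes (not the Megatron API itself):

    import torch

    max_seq, max_batch, heads, head_dim = 16, 2, 4, 8
    inference_key_memory = torch.empty(max_seq, max_batch, heads, head_dim)

    sequence_len_offset = 5                                 # tokens cached so far
    new_key = torch.randn(1, max_batch, heads, head_dim)    # one new token

    sequence_start = sequence_len_offset
    sequence_end = sequence_start + new_key.size(0)
    inference_key_memory[sequence_start:sequence_end, ...] = new_key   # append in place
    key_layer = inference_key_memory[:sequence_end, ...]               # full cached prefix
    print(key_layer.shape)                                  # torch.Size([6, 2, 4, 8])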
         # ===================================
         # Raw attention scores. [b, np, s, s]
...
@@ -277,22 +311,6 @@ class ParallelAttention(MegatronModule):
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout
...
@@ -348,9 +366,6 @@ class ParallelAttention(MegatronModule):
         output, bias = self.dense(context_layer)

-        if get_key_value:
-            output = [output, present]
-
         return output, bias

...
@@ -368,14 +383,18 @@ def get_bias_dropout_add(training):
 @torch.jit.script
-def bias_dropout_add_fused_train(x, bias, residual, prob):
-    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+def bias_dropout_add_fused_train(x: torch.Tensor,
+                                 bias: torch.Tensor,
+                                 residual: torch.Tensor,
+                                 prob: float) -> torch.Tensor:
     return bias_dropout_add(x, bias, residual, prob, True)


 @torch.jit.script
-def bias_dropout_add_fused_inference(x, bias, residual, prob):
-    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+def bias_dropout_add_fused_inference(x: torch.Tensor,
+                                     bias: torch.Tensor,
+                                     residual: torch.Tensor,
+                                     prob: float) -> torch.Tensor:
     return bias_dropout_add(x, bias, residual, prob, False)
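Both fused variants delegate to a plain bias_dropout_add defined earlier in this file (not shown in this hunk). A sketch of what that composition amounts to, bias add followed by dropout and the residual connection:

    import torch

    def bias_dropout_add(x, bias, residual, prob, training):
        # bias add -> dropout -> residual add, matching the call sites above
        out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
        return residual + out

    x, bias, residual = torch.randn(4, 8), torch.zeros(8), torch.randn(4, 8)
    out = bias_dropout_add(x, bias, residual, prob=0.1, training=False)
    print(out.shape)    # torch.Size([4, 8])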
...
@@ -404,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)

         # Self attention.
         self.self_attention = ParallelAttention(
...
@@ -419,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
...
@@ -430,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule):
             # Layernorm on the attention output.
             self.post_inter_attention_layernorm = LayerNorm(
                 args.hidden_size,
-                eps=args.layernorm_epsilon)
+                eps=args.layernorm_epsilon,
+                no_persist_layer_norm=args.no_persist_layer_norm)

         # MLP
         self.mlp = ParallelMLP(init_method,
...
@@ -438,20 +460,17 @@ class ParallelTransformerLayer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                inference_params=None):
         # hidden_states: [b, s, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
-
-        if get_key_value:
-            attention_output, presents = attention_output
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                inference_params=inference_params)

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
...
@@ -521,9 +540,6 @@ class ParallelTransformerLayer(MegatronModule):
                 residual,
                 self.hidden_dropout)

-        if get_key_value:
-            output = [output, presents]
-
         return output

...
@@ -544,13 +560,13 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = None

         # Store activation checkpointing flags.
-        self.checkpoint_activations = args.checkpoint_activations
-        self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.activations_checkpoint_method = args.activations_checkpoint_method
+        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

         # Number of layers.
-        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
-            'num_layers must be divisible by pipeline_model_parallel_size'
-        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
+        self.num_layers = mpu.get_num_layers(
+            args, args.model_type == ModelType.encoder_and_decoder)

         # Transformer layers.
         def build_layer(layer_number):
...
@@ -589,7 +605,8 @@ class ParallelTransformer(MegatronModule):
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
-                eps=args.layernorm_epsilon)
+                eps=args.layernorm_epsilon,
+                no_persist_layer_norm=args.no_persist_layer_norm)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
...
@@ -609,14 +626,32 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
+        if self.activations_checkpoint_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # A method to further reduce memory usage by reducing the number of checkpoints.
+            l = 0
+            while l < self.num_layers:
+                hidden_states = mpu.checkpoint(
+                    custom(l, l + self.activations_checkpoint_num_layers),
+                    self.distribute_checkpointed_activations,
+                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                l += self.activations_checkpoint_num_layers
+        elif self.activations_checkpoint_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # A method to fully use the device memory by removing redundant re-computation.
+            for l in range(self.num_layers):
+                if l < self.activations_checkpoint_num_layers:
+                    hidden_states = mpu.checkpoint(
+                        custom(l, l + 1),
+                        self.distribute_checkpointed_activations,
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
+                    hidden_states = custom(l, l + 1)(
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+        else:
+            raise ValueError("Invalid activation checkpoint method.")

         return hidden_states
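The two schedules above differ only in which chunks get recomputed: 'uniform' checkpoints every chunk of activations_checkpoint_num_layers layers, while 'block' checkpoints only the first few layers and runs the rest normally. A self-contained sketch using torch.utils.checkpoint in place of mpu.checkpoint; layer count and sizes are made up:

    import torch
    from torch.utils.checkpoint import checkpoint

    layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])

    def custom(start, end):
        def custom_forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return custom_forward

    x = torch.randn(4, 8, requires_grad=True)

    # 'uniform': checkpoint every chunk of `chunk` layers.
    chunk, h, l = 2, x, 0
    while l < len(layers):
        h = checkpoint(custom(l, l + chunk), h)
        l += chunk

    # 'block': checkpoint only the first `num_ckpt` layers, run the rest as usual.
    num_ckpt, h = 2, x
    for l in range(len(layers)):
        if l < num_ckpt:
            h = checkpoint(custom(l, l + 1), h)
        else:
            h = custom(l, l + 1)(h)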
...
@@ -630,18 +665,14 @@ class ParallelTransformer(MegatronModule):
            forward_step_func"""
         self.input_tensor = input_tensor

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):

         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
-            assert not self.checkpoint_activations, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+        if inference_params:
+            assert self.activations_checkpoint_method is None, \
+                'inference does not work with activation checkpointing'

         if self.pre_process:
             # Data format change to avoid explicit transposes : [b s h] --> [s b h].
...
@@ -658,28 +689,21 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
                                                        encoder_output,
                                                        enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
...
@@ -688,7 +712,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-            if get_key_value:
-                output = [output, presents]

         return output
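The inference_params object threaded through the forward passes above only needs a handful of attributes: the maximum batch and sequence sizes, the current offsets, and a per-layer key/value cache dict. A minimal stand-in showing that contract; the real container lives elsewhere in Megatron and may differ:

    from dataclasses import dataclass, field

    @dataclass
    class SimpleInferenceParams:
        max_batch_size: int
        max_sequence_len: int
        batch_size_offset: int = 0
        sequence_len_offset: int = 0
        key_value_memory_dict: dict = field(default_factory=dict)

    params = SimpleInferenceParams(max_batch_size=2, max_sequence_len=128)
    params.sequence_len_offset += 1   # advance after each generated token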
megatron/mpu/__init__.py
View file @
6fd0b406
...
@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
 from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
+from .initialize import is_rank_in_embedding_group
+from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
+from .initialize import is_pipeline_stage_at_split
+from .initialize import get_num_layers
 from .initialize import get_tensor_model_parallel_src_rank
 from .initialize import get_pipeline_model_parallel_first_rank
 from .initialize import get_pipeline_model_parallel_last_rank
...
@@ -56,9 +60,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
 from .random import gather_split_1d_tensor
 from .random import split_tensor_into_1d_equal_chunks
...
megatron/mpu/initialize.py
View file @
6fd0b406
...
@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...
@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

+# A list of ranks that have a copy of the embedding.
+_EMBEDDING_GLOBAL_RANKS = None
+
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None


 def is_unitialized():
...
@@ -52,13 +56,19 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
                               pipeline_model_parallel_size_=1,
-                              virtual_pipeline_model_parallel_size_=None):
+                              virtual_pipeline_model_parallel_size_=None,
+                              pipeline_model_parallel_split_rank_=None):
     """
     Initialize model data parallel groups.

     Arguments:
-        tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
-        pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
+        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
+        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
+        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
+                                              pipeline).
+        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
+                                            rank in pipeline with split point.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...
@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

+    if pipeline_model_parallel_split_rank_ is not None:
+        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
+
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.
...
@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
+    global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, \
         'embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
...
@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         # first and last stages).
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
+            if pipeline_model_parallel_split_rank_ is not None and \
+                    pipeline_model_parallel_split_rank_ not in embedding_ranks:
+                embedding_ranks = [ranks[0],
+                                   ranks[pipeline_model_parallel_split_rank_],
+                                   ranks[-1]]
         else:
             embedding_ranks = ranks
         group = torch.distributed.new_group(embedding_ranks)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
+        if rank in ranks:
+            _EMBEDDING_GLOBAL_RANKS = embedding_ranks


 def model_parallel_is_initialized():
...
@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


+def get_num_layers(args, is_encoder_and_decoder_model):
+    """Compute the number of transformer layers resident on the current rank."""
+    if get_pipeline_model_parallel_world_size() > 1:
+        if is_encoder_and_decoder_model:
+            assert args.pipeline_model_parallel_split_rank is not None
+            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
+            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+            assert args.num_layers % num_ranks_in_encoder == 0, \
+                'num_layers must be divisible by number of ranks given to encoder'
+            assert args.num_layers % num_ranks_in_decoder == 0, \
+                'num_layers must be divisible by number of ranks given to decoder'
+            if is_pipeline_stage_before_split():
+                num_layers = args.num_layers // num_ranks_in_encoder
+            else:
+                num_layers = args.num_layers // num_ranks_in_decoder
+        else:
+            assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
+                'num_layers must be divisible by pipeline_model_parallel_size'
+            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
+    else:
+        num_layers = args.num_layers
+    return num_layers
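A worked example of the split arithmetic in get_num_layers, with all numbers made up: 24 layers, a pipeline world size of 4, and the encoder/decoder split at rank 2 give each encoder rank (0-1) and each decoder rank (2-3) 24 // 2 = 12 layers of its respective block:

    num_layers, pipeline_world_size, split_rank = 24, 4, 2   # hypothetical config

    num_ranks_in_encoder = split_rank                        # ranks 0 and 1
    num_ranks_in_decoder = pipeline_world_size - split_rank  # ranks 2 and 3
    print(num_layers // num_ranks_in_encoder)                # 12 encoder layers per rank
    print(num_layers // num_ranks_in_decoder)                # 12 decoder layers per rank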

 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
...
@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
             get_pipeline_model_parallel_world_size() - 1)


+def is_rank_in_embedding_group(ignore_virtual=False):
+    """Return true if current rank is in embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _EMBEDDING_GLOBAL_RANKS
+    if ignore_virtual:
+        return rank in _EMBEDDING_GLOBAL_RANKS
+    if rank in _EMBEDDING_GLOBAL_RANKS:
+        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
+            return is_pipeline_first_stage(ignore_virtual=False)
+        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+            return is_pipeline_last_stage(ignore_virtual=False)
+        else:
+            return True
+    return False
+
+
+def is_pipeline_stage_before_split(rank=None):
+    """Return True if pipeline stage executes encoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_after_split(rank=None):
+    """Return True if pipeline stage executes decoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_at_split():
+    """Return true if pipeline stage executes decoder block and next
+    stage executes encoder block for a model with both encoder and
+    decoder."""
+    rank = get_pipeline_model_parallel_rank()
+    return is_pipeline_stage_before_split(rank) and \
+        is_pipeline_stage_after_split(rank + 1)
+
+
 def get_virtual_pipeline_model_parallel_rank():
     """Return the virtual pipeline-parallel rank."""
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
...
@@ -356,9 +457,13 @@ def get_data_parallel_rank():

 def destroy_model_parallel():
     """Set the groups to none."""
+    global _MODEL_PARALLEL_GROUP
+    _MODEL_PARALLEL_GROUP = None
     global _TENSOR_MODEL_PARALLEL_GROUP
     _TENSOR_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
+    global _EMBEDDING_GROUP
+    _EMBEDDING_GROUP = None
megatron/mpu/layers.py
View file @
6fd0b406
...
@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
 from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
...
@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+    """
+    Column-parallel linear layer execution with asynchronous all-reduce
+    execution in backprop.
+    """
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        output = torch.matmul(input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        grad_input = grad_output.matmul(weight)
+        # Asynchronous all-reduce
+        handle = torch.distributed.all_reduce(
+            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+        # Delay the start of weight gradient computation shortly (3us) to have
+        # all-reduce scheduled first and have GPU resources allocated
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_weight = grad_output.t().matmul(input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        handle.wait()
+        return grad_input, grad_weight, grad_bias
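The backward pass above launches the input-gradient all-reduce asynchronously so it overlaps with the weight- and bias-gradient matmuls, then waits before returning. A single-process sketch of that call pattern (gloo backend, world size 1, arbitrary local port, purely for illustration):

    import torch
    import torch.distributed as dist

    if not dist.is_initialized():
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                                rank=0, world_size=1)

    grad_output = torch.randn(4, 8)
    weight = torch.randn(8, 8)
    inp = torch.randn(4, 8)

    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, async_op=True)   # communication starts here
    grad_weight = grad_output.t().matmul(inp)             # overlapped computation
    grad_bias = grad_output.sum(dim=0)
    handle.wait()                                         # ensure the reduce finished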
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
...
@@ -256,7 +288,7 @@ class ColumnParallelLinear(torch.nn.Module):
                 device=torch.cuda.current_device(), dtype=args.params_dtype))
             _initialize_affine_weight_gpu(self.weight, init_method,
                                           partition_dim=0, stride=stride)
         if bias:
             if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(
...
@@ -272,21 +304,35 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+        self.async_tensor_model_parallel_allreduce = (
+            not args.no_async_tensor_model_parallel_allreduce and
+            world_size > 1)

     def forward(self, input_):
-        # Set up backprop all-reduce.
-        input_parallel = copy_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        if self.async_tensor_model_parallel_allreduce:
+            input_shape = input_.shape
+            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
+            # Matrix multiply with asynchronous all-reduce execution
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
+                input_, self.weight, bias)
+            output_parallel = output_parallel.view(
+                input_shape[0], input_shape[1], output_parallel.shape[1])
+        else:
+            # Set up backprop all-reduce.
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+            # Matrix multiply.
+            output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
...
@@ -316,8 +362,8 @@ class RowParallelLinear(torch.nn.Module):
         keep_master_weight_for_test: This was added for testing and should be
                                      set to False. It returns the master weights
                                      used for initialization.
-        skip_bias_add: This was added to enable performance optimations where bias
-                       can be fused with other elementwise operations. we skip
+        skip_bias_add: This was added to enable performance optimization where bias
+                       can be fused with other elementwise operations. We skip
                        adding bias but instead return it.
     """
...
megatron/mpu/mappings.py
View file @
6fd0b406
...
@@ -20,7 +20,7 @@ from .utils import split_tensor_along_last_dim

 def _reduce(input_):
-    """All-reduce the the input tensor across model parallel group."""
+    """All-reduce the input tensor across model parallel group."""

     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
...
megatron/mpu/random.py
View file @
6fd0b406
...
@@ -24,7 +24,6 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable

-from megatron import get_args
 from megatron.memory import allocate_mem_buff

 from .initialize import get_data_parallel_rank
...
@@ -37,37 +36,6 @@ from .initialize import get_tensor_model_parallel_world_size
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

-# Whether apply model parallelsim to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initializ the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
-

 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.
...
@@ -101,14 +69,21 @@ def _set_cuda_rng_state(new_state, device=-1):
     _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
     start_index = partition_size * get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
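A plain-torch illustration of the chunking above, without any model-parallel state: flatten the tensor, take this rank's contiguous slice, and either keep a view of it or copy it into a fresh buffer (the new_buffer=True path). Rank and world size are made up:

    import torch

    world_size, rank = 4, 1
    tensor = torch.arange(32.0).view(4, 8)

    partition_size = torch.numel(tensor) // world_size
    start = partition_size * rank
    end = start + partition_size

    view_chunk = tensor.view(-1)[start:end]              # new_buffer=False: just a view
    copy_chunk = torch.empty(partition_size, dtype=tensor.dtype)
    copy_chunk.copy_(tensor.view(-1)[start:end])         # new_buffer=True: its own storage
    print(view_chunk.tolist())                           # [8.0, 9.0, ..., 15.0]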

 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
...
@@ -250,8 +225,10 @@ class CheckpointFunction(torch.autograd.Function):
     tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
         ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
...
@@ -263,16 +240,14 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
             ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)

         # Store everything.
         ctx.save_for_backward(*args)

         return outputs

     @staticmethod
...
@@ -281,7 +256,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
             inputs[0].data = gather_split_1d_tensor(inputs[0].data)
             inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
...
@@ -310,10 +285,11 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                       for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
megatron/optimizer/__init__.py
View file @
6fd0b406
...
@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
                                    args.clip_grad,
                                    args.log_num_zeros_in_grad,
                                    params_have_main_grad,
+                                   args.use_contiguous_buffers_in_local_ddp,
                                    args.bf16,
                                    grad_scaler)

     # FP32.
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
-                         params_have_main_grad)
+                         params_have_main_grad,
+                         args.use_contiguous_buffers_in_local_ddp)