OpenDAS / Megatron-LM / Commits / 3aca1415

Commit 3aca1415, authored Apr 29, 2024 by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04. See merge request !1

Parents: 0024a5c6, 1005e9d3
Pipeline #1806 passed with stage
Changes: 204 · Pipelines: 3
Showing 20 changed files with 912 additions and 2247 deletions (+912 / -2247)
megatron/global_vars.py                   +2    -2
megatron/initialize.py                    +121  -72
megatron/model/bert_model.py              +17   -26
megatron/model/classification.py          +3    -5
megatron/model/distributed.py             +2    -2
megatron/model/enums.py                   +4    -1
megatron/model/fused_layer_norm.py        +1    -1
megatron/model/fused_softmax.py           +3    -3
megatron/model/gpt_model.py               +12   -13
megatron/model/language_model.py          +78   -93
megatron/model/module.py                  +13   -15
megatron/model/multiple_choice.py         +3    -5
megatron/model/retro_transformer.py       +0    -1731
megatron/model/t5_model.py                +7    -17
megatron/model/transformer.py             +627  -233
megatron/model/vision/classification.py   +2    -1
megatron/model/vision/dino.py             +9    -7
megatron/model/vision/inpainting.py       +3    -2
megatron/model/vision/mit_backbone.py     +2    -7
megatron/model/vision/vit_backbone.py     +3    -11
megatron/global_vars.py  (view file @ 3aca1415)

@@ -80,7 +80,7 @@ def _set_signal_handler():
-def set_global_variables(args):
+def set_global_variables(args, build_tokenizer=True):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
     assert args is not None

@@ -89,7 +89,7 @@ def set_global_variables(args):
     set_args(args)
     _build_num_microbatches_calculator(args)
-    if args.vocab_file or args.tokenizer_model:
+    if build_tokenizer:
         _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
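Note: a minimal caller-side sketch (not part of this commit; the wrapper name setup_globals is illustrative and assumes Megatron-LM is importable) of how the new build_tokenizer flag can be used to skip Megatron's built-in tokenizer:

    from megatron.global_vars import set_global_variables

    def setup_globals(args, use_external_tokenizer=False):
        # When the application supplies its own tokenizer, ask Megatron
        # not to build one; otherwise keep the default behaviour.
        set_global_variables(args, build_tokenizer=not use_external_tokenizer)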
megatron/initialize.py  (view file @ 3aca1415)

@@ -15,36 +15,40 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron.core import mpu, tensor_parallel
-from megatron.arguments import (parse_args, validate_args)
+from megatron.arguments import parse_args, validate_args
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu


-def initialize_megatron(extra_args_provider=None, args_defaults={},
-                        ignore_unknown_args=False, allow_no_cuda=False):
+def initialize_megatron(
+    extra_args_provider=None,
+    args_defaults={},
+    ignore_unknown_args=False,
+    allow_no_cuda=False,
+):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds.
     `allow_no_cuda` should not be set unless using megatron for cpu only
     data processing. In general this arg should not be set unless you know
     what you are doing.
     Returns a function to finalize distributed env initialization
     (optionally, only when args.lazy_mpu_init == True)
     """
     if not allow_no_cuda:
         # Make sure cuda is available.
-        assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+        assert torch.cuda.is_available(), "Megatron requires CUDA."

     # Parse arguments
     args = parse_args(extra_args_provider, ignore_unknown_args)

-    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
-        assert args.load is not None, '--use-checkpoints-args requires --load argument'
+    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
+        assert args.load is not None, "--use-checkpoints-args requires --load argument"
         load_args_from_checkpoint(args)

     validate_args(args, args_defaults)

     # set global args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(args)

@@ -54,16 +58,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         args = get_args()
         # Pytorch distributed.
         _initialize_distributed()

         # Random seeds for reproducibility.
         if args.rank == 0:
-            print('> setting random seeds to {} ...'.format(args.seed))
+            print("> setting random seeds to {} ...".format(args.seed))
         _set_random_seed(args.seed, args.data_parallel_random_init)

     args = get_args()
     if args.lazy_mpu_init:
         # TODO is this still a necessary option?
         args.use_cpu_initialization = True
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)

@@ -95,11 +99,15 @@ def _compile_dependencies():
     # TODO: move this to ninja
     if torch.distributed.get_rank() == 0:
         start_time = time.time()
-        print('> compiling dataset index builder ...')
+        print("> compiling dataset index builder ...")
         from megatron.data.dataset_utils import compile_helper
         compile_helper()
-        print('>>> done with dataset index builder. Compilation time: {:.3f} '
-              'seconds'.format(time.time() - start_time), flush=True)
+        print(
+            ">>> done with dataset index builder. Compilation time: {:.3f} "
+            "seconds".format(time.time() - start_time),
+            flush=True,
+        )

     # ==================
     # Load fused kernels

@@ -107,41 +115,51 @@ def _compile_dependencies():
     # Custom kernel constraints check.
     seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
+    attn_batch_size = (
+        args.num_attention_heads / args.tensor_model_parallel_size
+    ) * args.micro_batch_size
     # Constraints on sequence length and attn_batch_size to enable warp based
     # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    custom_kernel_constraint = (
+        seq_len > 16
+        and seq_len <= 16384
+        and seq_len % 4 == 0
+        and attn_batch_size % 4 == 0
+    )
     # Print a warning.
-    if not ((args.fp16 or args.bf16) and
-            custom_kernel_constraint and
-            args.masked_softmax_fusion):
+    if not (
+        (args.fp16 or args.bf16)
+        and custom_kernel_constraint
+        and args.masked_softmax_fusion
+    ):
         if args.rank == 0:
-            print('WARNING: constraints for invoking optimized'
-                  ' fused softmax kernel are not met. We default'
-                  ' back to unfused kernel invocations.', flush=True)
+            print(
+                "WARNING: constraints for invoking optimized"
+                " fused softmax kernel are not met. We default"
+                " back to unfused kernel invocations.",
+                flush=True,
+            )

     # Always build on rank zero first.
     if torch.distributed.get_rank() == 0:
         start_time = time.time()
-        print('> compiling and loading fused kernels ...', flush=True)
-        fused_kernels.load(args)
+        print("> compiling and loading fused kernels ...", flush=True)
+        # fused_kernels.load(args)
         torch.distributed.barrier()
     else:
         torch.distributed.barrier()
-        fused_kernels.load(args)
+        # fused_kernels.load(args)
     # Simple barrier to make sure all ranks have passed the
     # compilation phase successfully before moving on to the
     # rest of the program. We think this might ensure that
     # the lock is released.
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
-        print('>>> done with compiling and loading fused kernels. '
-              'Compilation time: {:.3f} seconds'.format(
-                  time.time() - start_time), flush=True)
+        print(
+            ">>> done with compiling and loading fused kernels. "
+            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
+            flush=True,
+        )


 def _initialize_distributed():

@@ -152,45 +170,58 @@ def _initialize_distributed():
     if torch.distributed.is_initialized():

         if args.rank == 0:
-            print('torch distributed is already initialized, '
-                  'skipping initialization ...', flush=True)
-        args.rank = torch.distributed.get_rank()
-        args.world_size = torch.distributed.get_world_size()
+            print(
+                "torch distributed is already initialized, "
+                "skipping initialization ...",
+                flush=True,
+            )
+        #args.rank = torch.distributed.get_rank()
+        #args.world_size = torch.distributed.get_world_size()

     else:

         if args.rank == 0:
-            print('> initializing torch distributed ...', flush=True)
+            print("> initializing torch distributed ...", flush=True)
         # Manually set the device ids.
         if device_count > 0:
             device = args.rank % device_count
             if args.local_rank is not None:
-                assert args.local_rank == device, \
-                    'expected local-rank to be the same as rank % device-count.'
+                assert (
+                    args.local_rank == device
+                ), "expected local-rank to be the same as rank % device-count."
             else:
                 args.local_rank = device
             torch.cuda.set_device(device)

         # Call the init process
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
-            world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(minutes=args.distributed_timeout_minutes))
+            world_size=args.world_size,
+            rank=args.rank,
+            init_method=args.dist_url,
+            timeout=timedelta(minutes=args.distributed_timeout_minutes),
+        )

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
     if device_count > 0:
         if mpu.model_parallel_is_initialized():
-            print('model parallel is already initialized')
+            print("model parallel is already initialized")
         else:
-            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size,
-                                          args.pipeline_model_parallel_split_rank)
+            mpu.initialize_model_parallel(
+                args.tensor_model_parallel_size,
+                args.pipeline_model_parallel_size,
+                args.virtual_pipeline_model_parallel_size,
+                args.pipeline_model_parallel_split_rank,
+                args.fp8 is not None,
+            )
             if args.rank == 0:
-                print(f'> initialized tensor model parallel with size '
-                      f'{mpu.get_tensor_model_parallel_world_size()}')
-                print(f'> initialized pipeline model parallel with size '
-                      f'{mpu.get_pipeline_model_parallel_world_size()}')
+                print(
+                    f"> initialized tensor model parallel with size "
+                    f"{mpu.get_tensor_model_parallel_world_size()}"
+                )
+                print(
+                    f"> initialized pipeline model parallel with size "
+                    f"{mpu.get_pipeline_model_parallel_world_size()}"
+                )


 def _init_autoresume():

@@ -216,7 +247,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         if torch.cuda.device_count() > 0:
             tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
-        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
+        raise ValueError("Seed ({}) should be a positive integer.".format(seed))


 def write_args_to_tensorboard():

@@ -225,15 +256,14 @@ def write_args_to_tensorboard():
     writer = get_tensorboard_writer()
     if writer:
         for arg in vars(args):
-            writer.add_text(arg, str(getattr(args, arg)),
-                            global_step=args.iteration)
+            writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)


 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
-    TORCH_MAJOR = int(torch.__version__.split('.')[0])
-    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
     if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
         # nvfuser
         torch._C._jit_set_profiling_executor(True)

@@ -241,7 +271,7 @@ def set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_cpu(False)
         torch._C._jit_override_can_fuse_on_gpu(False)
         torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._jit_set_nvfuser_enabled(False)
         torch._C._debug_set_autodiff_subgraph_inlining(False)
     else:
         # legacy pytorch fuser

@@ -254,7 +284,7 @@ def set_jit_fusion_options():
 def _warmup_jit_function():
-    """ Compilie JIT functions before the main training steps """
+    """Compilie JIT functions before the main training steps"""
     args = get_args()
     if args.bf16:
         dtype = torch.bfloat16

@@ -264,11 +294,20 @@ def _warmup_jit_function():
         dtype = torch.float32

     # Warmup fused bias+gelu
-    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
-                      dtype=dtype, device='cuda')
-    input = torch.rand((args.seq_length, args.micro_batch_size,
-                        args.ffn_hidden_size // args.tensor_model_parallel_size),
-                       dtype=dtype, device='cuda')
+    bias = torch.rand(
+        args.ffn_hidden_size // args.tensor_model_parallel_size,
+        dtype=dtype,
+        device="cuda",
+    )
+    input = torch.rand(
+        (
+            args.seq_length,
+            args.micro_batch_size,
+            args.ffn_hidden_size // args.tensor_model_parallel_size,
+        ),
+        dtype=dtype,
+        device="cuda",
+    )
     # Warmup JIT fusions with the input grad_enable state of both forward
     # prop and recomputation
     for bias_grad, input_grad in zip([True, True], [False, True]):

@@ -282,15 +321,25 @@ def _warmup_jit_function():
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length
-    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
-                       dtype=dtype, device='cuda')
-    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
-                          dtype=dtype, device='cuda')
-    bias = torch.rand((args.hidden_size), dtype=dtype,
-                      device='cuda').expand_as(residual)
+    input = torch.rand(
+        (seq_length, args.micro_batch_size, args.hidden_size),
+        dtype=dtype,
+        device="cuda",
+    )
+    residual = torch.rand(
+        (seq_length, args.micro_batch_size, args.hidden_size),
+        dtype=dtype,
+        device="cuda",
+    )
+    bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
+        residual
+    )
     dropout_rate = 0.1
     # Warmup JIT fusions with the input grad_enable state of both forward
     # prop and recomputation
-    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True],
-                                                    [True, True]):
+    for input_grad, bias_grad, residual_grad in zip(
+        [False, True], [True, True], [True, True]
+    ):
         input.requires_grad = input_grad
         bias.requires_grad = bias_grad
         residual.requires_grad = residual_grad
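Note: the _initialize_distributed hunk above adds init_method=args.dist_url to the process-group setup. A minimal standalone sketch (not Megatron code; the backend, address, sizes, and timeout are placeholder assumptions) of the same explicit init_method style:

    from datetime import timedelta
    import torch

    # Single-process example using the gloo backend; in Megatron the URL,
    # rank, world size, and timeout all come from the parsed args.
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        world_size=1,
        rank=0,
        timeout=timedelta(minutes=10),
    )
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()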
megatron/model/bert_model.py  (view file @ 3aca1415)

@@ -47,31 +47,27 @@ class BertLMHead(MegatronModule):
     """Masked LM head for Bert

     Arguments:
+        config: TransformerConfig object
         mpu_vocab_size: model parallel size of vocabulary.
         hidden_size: hidden size
-        init_method: init method for weight initialization
-        layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: whether output logits being distributed or not.
     """

-    def __init__(self, mpu_vocab_size, hidden_size, init_method,
-                 layernorm_epsilon, parallel_output):
-        super(BertLMHead, self).__init__()
+    def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
+        super().__init__(config=config)
         args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

-        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
-        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
+        self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method)
+        setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
+        setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)

         self.layernorm = LayerNorm(hidden_size,
-                                   eps=layernorm_epsilon,
-                                   sequence_parallel=args.sequence_parallel)
+                                   eps=config.layernorm_epsilon,
+                                   sequence_parallel=config.sequence_parallel)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu

@@ -124,12 +120,13 @@ class BertModel(MegatronModule):
     """Bert Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=2,
                  add_binary_head=True,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
-        super(BertModel, self).__init__()
+        super().__init__(config=config)
         args = get_args()

         # TODO this option is not yet implemented in BERT

@@ -145,29 +142,23 @@ class BertModel(MegatronModule):
         if self.return_embeddings:
             assert self.post_process and self.add_binary_head

-        init_method = init_method_normal(args.init_method_std)
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
-
         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process)

-        self.initialize_word_embeddings(init_method_normal)
+        self.initialize_word_embeddings()
         if self.post_process:
-            self.lm_head = BertLMHead(self.word_embeddings_weight().size(0),
-                                      args.hidden_size, init_method,
-                                      args.layernorm_epsilon, parallel_output)
+            self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0),
+                                      config.hidden_size, config, parallel_output)
             self._lm_head_key = 'lm_head'
             self.binary_head = None
             if self.add_binary_head:
-                self.binary_head = get_linear_layer(args.hidden_size, 2,
-                                                    init_method)
+                self.binary_head = get_linear_layer(config.hidden_size, 2,
+                                                    config.init_method)
                 self._binary_head_key = 'binary_head'

     def set_input_tensor(self, input_tensor):

@@ -215,7 +206,7 @@ class BertModel(MegatronModule):
             return post_language_model_processing(lm_output, pooled_output,
                                                   self.lm_head, self.binary_head,
                                                   lm_labels,
-                                                  self.word_embeddings_weight(),
+                                                  self.shared_embedding_or_output_weight(),
                                                   self.fp16_lm_cross_entropy)
         else:
             return lm_output
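Note: the pattern in this file (and in the model files below) replaces loose constructor arguments with a single config object. A small illustrative sketch of that refactor direction, using a made-up dataclass rather than Megatron's actual TransformerConfig:

    from dataclasses import dataclass
    from typing import Callable

    import torch

    @dataclass
    class HeadConfig:
        # Stand-in for the handful of fields BertLMHead now reads from config.
        init_method: Callable
        layernorm_epsilon: float = 1e-5
        sequence_parallel: bool = False

    # Old style: BertLMHead(vocab, hidden, init_method, layernorm_epsilon, parallel_output)
    # New style: BertLMHead(vocab, hidden, config, parallel_output)
    cfg = HeadConfig(init_method=torch.nn.init.xavier_normal_)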
megatron/model/classification.py  (view file @ 3aca1415)

@@ -17,25 +17,23 @@ from .module import MegatronModule
 class Classification(MegatronModule):

     def __init__(self,
+                 config,
                  num_classes,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):
-        super(Classification, self).__init__(share_word_embeddings=False)
+        super().__init__(config=config, share_embeddings_and_output_weights=False)
         args = get_args()

         self.num_classes = num_classes
         self.pre_process = pre_process
         self.post_process = post_process
-        init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)
megatron/model/distributed.py  (view file @ 3aca1415)

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 from abc import ABC
 from abc import abstractmethod

@@ -73,7 +73,7 @@ class DistributedDataParallelBase(MegatronModule, ABC):
 class DistributedDataParallel(DistributedDataParallelBase):
-    """DDP with contiguous buffers options to storre and accumulate gradients.
+    """DDP with contiguous buffers options to store and accumulate gradients.
     This class:
         - has the potential to reduce memory fragmentation.
         - provides the option to do the gradient accumulation
megatron/model/enums.py  (view file @ 3aca1415)

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 import enum

 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
+    retro_encoder = 3
+    retro_decoder = 4
+    retro_decoder_with_retriever = 5

 class AttnType(enum.Enum):
     self_attn = 1
megatron/model/fused_layer_norm.py  (view file @ 3aca1415)

@@ -14,7 +14,7 @@ from megatron.core.utils import make_viewless_tensor
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
-    HAVE_PERSIST_LAYER_NORM = True
+    HAVE_PERSIST_LAYER_NORM = False
 except:
     HAVE_PERSIST_LAYER_NORM = False
megatron/model/fused_softmax.py  (view file @ 3aca1415)

@@ -155,12 +155,12 @@ class FusedScaleMaskSoftmax(nn.Module):
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
-            and 16 < sk <= 4096  # sk must be 16 ~ 2048
+            and 16 < sk <= 16384  # sk must be 16 ~ 16384
             and sq % 4 == 0  # sq must be divisor of 4
             and sk % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
-            if 0 <= sk <= 4096:
+            if 0 <= sk <= 16384:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                 if self.attn_mask_type == AttnMaskType.causal:
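Note: for reference, the eligibility check above amounts to the following standalone arithmetic (an illustrative sketch only, mirroring the constraints with the new 16384 ceiling):

    def can_use_fused_softmax(sq, sk, b, np_heads,
                              input_in_float16=True, fusion_enabled=True):
        # sq/sk: query/key sequence lengths, b: micro-batch, np_heads: attention heads.
        attn_batches = b * np_heads
        return (
            fusion_enabled
            and input_in_float16
            and 16 < sk <= 16384
            and sq % 4 == 0
            and sk % 4 == 0
            and attn_batches % 4 == 0
        )

    # Example: 4096-token self-attention, 32 heads, micro-batch 4 -> eligible.
    assert can_use_fused_softmax(sq=4096, sk=4096, b=4, np_heads=32)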
megatron/model/gpt_model.py  (view file @ 3aca1415)

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """GPT-2 model."""

@@ -11,8 +11,6 @@ from .module import MegatronModule
 from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
-from .utils import init_method_normal
-from .utils import scaled_init_method_normal


 def post_language_model_processing(lm_output, labels, logit_weights,

@@ -46,12 +44,13 @@ class GPTModel(MegatronModule):
     """GPT-2 Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=0,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
         args = get_args()
-        super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
+        super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

         self.parallel_output = parallel_output
         self.pre_process = pre_process

@@ -60,39 +59,39 @@ class GPTModel(MegatronModule):
         self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.causal,
-            init_method=init_method_normal(args.init_method_std),
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)

         if not args.untie_embeddings_and_output_weights:
-            self.initialize_word_embeddings(init_method_normal)
+            self.initialize_word_embeddings()

     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
                 labels=None, tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            ret_input_ids=ret_input_ids,
-            ret_position_ids=ret_position_ids,
-            ret_attn_mask=ret_attn_mask,
+            retriever_input_ids=retriever_input_ids,
+            retriever_position_ids=retriever_position_ids,
+            retriever_attn_mask=retriever_attn_mask,
             inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
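Note: untie_embeddings_and_output_weights controls whether the output projection reuses the input embedding matrix. A tiny self-contained sketch (illustrative only, not Megatron's GPTModel) of what that tying means:

    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        def __init__(self, vocab_size, hidden_size, untie=False):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, hidden_size)
            self.output = nn.Linear(hidden_size, vocab_size, bias=False)
            if not untie:
                # Shared parameter: gradients from both uses accumulate here.
                self.output.weight = self.embedding.weight

        def forward(self, token_ids):
            return self.output(self.embedding(token_ids))

    model = TinyLM(vocab_size=100, hidden_size=16)
    assert model.output.weight is model.embedding.weight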
megatron/model/language_model.py  (view file @ 3aca1415)

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """Transformer based language model."""

@@ -7,11 +7,11 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron.core import mpu, tensor_parallel
+from megatron.core.enums import ModelType
+from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding

-from .enums import LayerType, AttnMaskType
+from .enums import AttnMaskType, LayerType
 from .module import MegatronModule
-from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
-from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
 from .transformer import ParallelTransformer
 from .utils import get_linear_layer
 from .utils import init_method_normal, scaled_init_method_normal

@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
         bias=bias,
         gradient_accumulation_fusion=args.gradient_accumulation_fusion,
         async_grad_allreduce=async_grad_allreduce,
-        sequence_parallel_enabled=args.sequence_parallel)
+        sequence_parallel=args.sequence_parallel)
     # Gather if needed.

     if parallel_output:

@@ -48,26 +48,24 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


-def get_language_model(num_tokentypes, add_pooler,
-                       encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_encoder=True,
+def get_language_model(config, num_tokentypes, add_pooler,
+                       encoder_attn_mask_type,
+                       add_encoder=True,
                        add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()
+    if config.init_method is None:
+        config.init_method = init_method_normal(config.init_method_std)

-    if init_method is None:
-        init_method = init_method_normal(args.init_method_std)
-
-    if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
+    if config.output_layer_init_method is None:
+        config.output_layer_init_method = scaled_init_method_normal(
+            config.init_method_std, config.num_layers)

     # Language model.
     language_model = TransformerLanguageModel(
-        init_method,
-        scaled_init_method,
+        config,
         encoder_attn_mask_type,
         num_tokentypes=num_tokentypes,
         add_encoder=add_encoder,

@@ -131,6 +129,10 @@ class Embedding(MegatronModule):
         init_method: weight initialization method
         num_tokentypes: size of the token-type embeddings. 0 value
                         will ignore this embedding
+        embedding_weights_in_fp32: casts word embedding weights to
+                                   fp32 before sampling. Required to
+                                   maintain reproducibility when
+                                   training in bf16.
     """

     def __init__(self,

@@ -138,28 +140,26 @@ class Embedding(MegatronModule):
                  vocab_size,
                  max_sequence_length,
                  embedding_dropout_prob,
-                 init_method,
-                 num_tokentypes=0):
+                 config,
+                 num_tokentypes=0,
+                 embedding_weights_in_fp32=False):
         super(Embedding, self).__init__()

         self.hidden_size = hidden_size
-        self.init_method = init_method
+        self.init_method = config.init_method
         self.num_tokentypes = num_tokentypes

         args = get_args()

         # Word embeddings (parallel).
+        self.embedding_weights_in_fp32 = embedding_weights_in_fp32
+        self.params_dtype = args.params_dtype
         self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-            vocab_size, self.hidden_size,
-            init_method=self.init_method,
-            params_dtype=args.params_dtype,
-            use_cpu_initialization=args.use_cpu_initialization,
-            perform_initialization=args.perform_initialization)
+            vocab_size, self.hidden_size, config=config, init_method=config.init_method)
         self._word_embeddings_key = 'word_embeddings'

         # Position embedding (serial).
-        self.add_position_embedding = args.add_position_embedding
+        self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
         if self.add_position_embedding:
             self.position_embeddings = torch.nn.Embedding(
                 max_sequence_length, self.hidden_size)

@@ -182,7 +182,7 @@ class Embedding(MegatronModule):
         else:
             self.tokentype_embeddings = None

         self.fp32_residual_connection = args.fp32_residual_connection
         self.sequence_parallel = args.sequence_parallel
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

@@ -217,7 +217,12 @@ class Embedding(MegatronModule):
     def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Embeddings.
+        if self.embedding_weights_in_fp32:
+            self.word_embeddings = self.word_embeddings.to(torch.float32)
         words_embeddings = self.word_embeddings(input_ids)
+        if self.embedding_weights_in_fp32:
+            words_embeddings = words_embeddings.to(self.params_dtype)
+            self.word_embeddings = self.word_embeddings.to(self.params_dtype)
         if self.add_position_embedding:
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = words_embeddings + position_embeddings

@@ -326,8 +331,7 @@ class TransformerLanguageModel(MegatronModule):
     """

     def __init__(self,
-                 init_method,
-                 output_layer_init_method,
+                 config,
                  encoder_attn_mask_type,
                  num_tokentypes=0,
                  add_encoder=True,

@@ -337,21 +341,22 @@ class TransformerLanguageModel(MegatronModule):
                  pre_process=True,
                  post_process=True):
         args = get_args()
-        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
         if args.untie_embeddings_and_output_weights: assert not add_decoder
-        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
+        super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

         self.pre_process = pre_process
         self.post_process = post_process
-        self.hidden_size = args.hidden_size
+        self.hidden_size = config.hidden_size
         self.num_tokentypes = num_tokentypes
-        self.init_method = init_method
+        self.init_method = config.init_method
         self.add_encoder = add_encoder
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.add_retriever = args.retro_add_retriever
         self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         # Embeddings.

@@ -360,14 +365,15 @@ class TransformerLanguageModel(MegatronModule):
             args.padded_vocab_size,
             args.max_position_embeddings,
             args.hidden_dropout,
-            self.init_method,
-            self.num_tokentypes)
+            config,
+            self.num_tokentypes,
+            args.embedding_weights_in_fp32)
         self._embedding_key = 'embedding'

         # Rotary positional embeddings
-        self.use_rotary_position_embeddings = \
-            args.use_rotary_position_embeddings
-        if args.use_rotary_position_embeddings:
+        self.use_rotary_position_embeddings = \
+            args.position_embedding_type == 'rope'
+        if self.use_rotary_position_embeddings:
             self.seq_length = args.seq_length
             rotary_dim = args.hidden_size // args.num_attention_heads \
                 if args.kv_channels is None else args.kv_channels

@@ -378,41 +384,22 @@ class TransformerLanguageModel(MegatronModule):
             # partial rotary embeddings, which is better than full rotary
             # Wang and Komatsuzaki et al
             # https://github.com/kingoflolz/mesh-transformer-jax/
-            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
-
-        # Retriever (bi-directional transformer with cross attention)
-        if args.retro_add_retriever:
-            self.retriever = ParallelRetroEncoder(
-                self.init_method,
-                output_layer_init_method,
-                self_attn_mask_type=AttnMaskType.padding,
-                pre_process=self.pre_process,
-                post_process=False,
-            )
-            self._retriever_key = 'retriever'
-        else:
-            self.retriever = None
+            self.rotary_pos_emb = RotaryEmbedding(
+                rotary_dim,
+                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
+            )

         # Encoder (usually set to True, False if part of an encoder-decoder
         # architecture and in encoder-only stage).
         if self.add_encoder:
-            if args.retro_add_retriever:
-                self.encoder = ParallelRetroTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                    retriever=self.retriever,
-                )
-            else:
-                self.encoder = ParallelTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                )
+            self.encoder = ParallelTransformer(
+                config,
+                model_type=args.model_type if not args.retro_add_retriever \
+                    else ModelType.retro_decoder,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process,
+            )
             self._encoder_key = 'encoder'
         else:
             self.encoder = None

@@ -421,8 +408,8 @@ class TransformerLanguageModel(MegatronModule):
         # architecture and in decoder-only stage).
         if self.add_decoder:
             self.decoder = ParallelTransformer(
-                self.init_method,
-                output_layer_init_method,
+                config,
+                model_type=args.model_type,
                 layer_type=LayerType.decoder,
                 self_attn_mask_type=self.decoder_attn_mask_type,
                 pre_process=self.pre_process,

@@ -441,8 +428,9 @@ class TransformerLanguageModel(MegatronModule):
             self.output_layer = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 args.padded_vocab_size,
-                bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
-                init_method=self.init_method)
+                config=config,
+                init_method=self.init_method,
+                bias=False)  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
             self._output_layer_key = 'output_layer'

     def set_input_tensor(self, input_tensor):

@@ -475,19 +463,14 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

-        # Retriever embedding.
-        if self.retriever and self.pre_process:
-            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
-                                             tokentype_ids=tokentype_ids)
-        else:
-            retriever_input = None
-
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,

@@ -495,31 +478,33 @@ class TransformerLanguageModel(MegatronModule):
         else:
             encoder_input = None

+        # Retriever embedding.
+        if self.add_retriever and self.pre_process:
+            retriever_input = self.embedding(retriever_input_ids,
+                                             retriever_position_ids,
+                                             tokentype_ids=tokentype_ids)
+        else:
+            retriever_input = None
+
         # Rotary positional embeddings
         rotary_pos_emb = None
         if self.use_rotary_position_embeddings:
             if inference_params is not None:
-                rotary_pos_emb = \
-                    self.rotary_pos_emb(inference_params.max_sequence_len)
+                rotary_pos_emb = \
+                    self.rotary_pos_emb(inference_params.max_sequence_length)
             else:
                 rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
-                if self.retriever:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        retriever_output=retriever_input,
-                        retriever_attn_mask=ret_attn_mask,
-                        inference_params=inference_params)
-                else:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        inference_params=inference_params,
-                        rotary_pos_emb=rotary_pos_emb)
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    retriever_input=retriever_input,
+                    retriever_attn_mask=retriever_attn_mask,
+                    inference_params=inference_params,
+                    rotary_pos_emb=rotary_pos_emb)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
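Note: the RotaryEmbedding constructed above now accepts seq_len_interpolation_factor. A minimal sketch (not Megatron's implementation) of what that factor does to the rotary position frequencies:

    import torch

    def rotary_frequencies(rotary_dim, seq_len, base=10000.0,
                           seq_len_interpolation_factor=None):
        # Inverse frequencies over even channel indices.
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        positions = torch.arange(seq_len, dtype=inv_freq.dtype)
        if seq_len_interpolation_factor is not None:
            # Stretch positions so a longer sequence reuses the trained range.
            positions = positions / seq_len_interpolation_factor
        return torch.outer(positions, inv_freq)  # [seq_len, rotary_dim // 2]

    freqs = rotary_frequencies(rotary_dim=64, seq_len=2048,
                               seq_len_interpolation_factor=2.0)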
megatron/model/module.py  (view file @ 3aca1415)

@@ -25,9 +25,10 @@ class MegatronModule(torch.nn.Module):
     """Megatron specific extensions of torch Module with support
     for pipelining."""

-    def __init__(self, share_word_embeddings=True):
+    def __init__(self, config=None, share_embeddings_and_output_weights=True):
         super(MegatronModule, self).__init__()
-        self.share_word_embeddings = share_word_embeddings
+        self.config = config
+        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights


     def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

@@ -36,21 +37,21 @@ class MegatronModule(torch.nn.Module):
         return self.state_dict(prefix=prefix, keep_vars=keep_vars)


-    def word_embeddings_weight(self):
+    def shared_embedding_or_output_weight(self):
         if self.pre_process:
             return self.language_model.embedding.word_embeddings.weight
         else:
-            if not self.share_word_embeddings:
-                raise Exception('word_embeddings_weight() called for last '
-                                'stage, but share_word_embeddings is false')
+            if not self.share_embeddings_and_output_weights:
+                raise Exception('shared_embedding_or_output_weight() called for last '
+                                'stage, but share_embeddings_and_output_weights is false')
             return self.word_embeddings.weight


-    def initialize_word_embeddings(self, init_method_normal):
+    def initialize_word_embeddings(self):
         args = get_args()
-        if not self.share_word_embeddings:
+        if not self.share_embeddings_and_output_weights:
             raise Exception('initialize_word_embeddings() was called but '
-                            'share_word_embeddings is false')
+                            'share_embeddings_and_output_weights is false')

         # This function just initializes the word embeddings in the final stage
         # when we are using pipeline parallelism. Nothing to do if we aren't

@@ -76,11 +77,8 @@ class MegatronModule(torch.nn.Module):
         # set word_embeddings weights to 0 here, then copy first
         # stage's weights using all_reduce below.
         self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-            args.padded_vocab_size, args.hidden_size,
-            init_method=init_method_normal(args.init_method_std),
-            params_dtype=args.params_dtype,
-            use_cpu_initialization=args.use_cpu_initialization,
-            perform_initialization=args.perform_initialization)
+            args.padded_vocab_size, self.config.hidden_size,
+            config=self.config, init_method=self.config.init_method)
         self.word_embeddings.weight.data.fill_(0)
         self.word_embeddings.weight.shared = True

@@ -103,7 +101,7 @@ class MegatronModule(torch.nn.Module):
         # Ensure that first and last stages have the same initial parameter
         # values.
         if mpu.is_rank_in_embedding_group():
-            torch.distributed.all_reduce(self.word_embeddings_weight().data,
+            torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
                                          group=mpu.get_embedding_group())

         # Ensure that encoder(first stage) and decoder(split stage) position
megatron/model/multiple_choice.py  (view file @ 3aca1415)

@@ -17,23 +17,21 @@ from .module import MegatronModule
 class MultipleChoice(MegatronModule):

     def __init__(self,
+                 config,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):
-        super(MultipleChoice, self).__init__(share_word_embeddings=False)
+        super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
         args = get_args()

-        init_method = init_method_normal(args.init_method_std)
         self.pre_process = pre_process
         self.post_process = post_process

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)
megatron/model/retro_transformer.py  deleted (100644 → 0, view file @ 0024a5c6)

The entire file is removed by this merge. Its former contents:

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Retro Transformer.

** Special note about this file **

Many classes and methods in this file directly parallel those in transformer.py
in name and utility. However, due to 1) subtle changes in the code over time
(i.e., transposes and contexts), and 2) other code that is soon to be merged,
this file will *temporarily* remain as is, until a larger integration is
complete.
"""

import math
import numpy as np
import torch
import torch.nn.functional as F

from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, init_method_normal
from .module import MegatronModule
from .transformer import _get_num_layers, ParallelMLP, NoopTransformerLayer

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""


class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    *Note: differs from transformer.py/DropPath in hidden_state transpose.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output
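Note: a compact standalone sketch (not from this repository) of how a stochastic-depth module like the DropPath above is typically applied to the residual branch of a block; TinyBlock and its shapes are illustrative only:

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, hidden_size, drop_prob=0.1):
            super().__init__()
            self.mlp = nn.Linear(hidden_size, hidden_size)
            self.drop_prob = drop_prob

        def drop_path(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # One keep/drop decision per sample, broadcast over the rest.
            mask = keep_prob + torch.rand(
                (x.shape[0],) + (1,) * (x.ndim - 1),
                dtype=x.dtype, device=x.device)
            mask.floor_()
            return x.div(keep_prob) * mask

        def forward(self, x):
            return x + self.drop_path(self.mlp(x))

    y = TinyBlock(hidden_size=8)(torch.randn(4, 8))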
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        b = hidden_states.size(0)
        s = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [b, s, h] to [b*s, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
        max_ind = max_ind.view(-1)  # [b*s]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.view(b, s, h)
        output_bias_total = output_bias_total.view(b, s, h)

        return output_total, output_bias_total
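Note: a compact standalone sketch (illustrative only, using plain nn.Linear experts rather than ParallelMLP) of the same top-1 routing idea implemented by SwitchMLP above:

    import torch
    import torch.nn as nn

    def top1_route(hidden_states, router, experts):
        # hidden_states: [b, s, h]; router maps h -> num_experts.
        b, s, h = hidden_states.shape
        probs = torch.softmax(router(hidden_states), dim=-1)
        max_prob, max_ind = probs.max(dim=-1)          # [b, s]
        flat = hidden_states.view(-1, h)               # [b*s, h]
        out = torch.empty_like(flat)
        for expert_num, expert in enumerate(experts):
            idx = (max_ind.view(-1) == expert_num).nonzero(as_tuple=True)[0]
            if idx.numel():
                out[idx] = expert(flat[idx])
        # Scale each token's output by its routing probability.
        return (out * max_prob.view(-1, 1)).view(b, s, h)

    router = nn.Linear(8, 4)
    experts = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
    y = top1_route(torch.randn(2, 5, 8), router, experts)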
class
ParallelAttention
(
MegatronModule
):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def
__init__
(
self
,
init_method
,
output_layer_init_method
,
layer_number
,
attention_type
=
AttnType
.
self_attn
,
attn_mask_type
=
AttnMaskType
.
padding
):
super
(
ParallelAttention
,
self
).
__init__
()
args
=
get_args
()
self
.
fp16
=
args
.
fp16
self
.
bf16
=
args
.
bf16
self
.
apply_query_key_layer_scaling
=
args
.
apply_query_key_layer_scaling
self
.
attention_softmax_in_fp32
=
args
.
attention_softmax_in_fp32
if
self
.
apply_query_key_layer_scaling
:
self
.
attention_softmax_in_fp32
=
True
self
.
layer_number
=
max
(
1
,
layer_number
)
self
.
attention_type
=
attention_type
self
.
attn_mask_type
=
attn_mask_type
self
.
params_dtype
=
args
.
params_dtype
projection_size
=
args
.
kv_channels
*
args
.
num_attention_heads
# Per attention head and per partition values.
world_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
hidden_size_per_partition
=
core_utils
.
divide
(
projection_size
,
world_size
)
self
.
hidden_size_per_attention_head
=
core_utils
.
divide
(
projection_size
,
args
.
num_attention_heads
)
self
.
num_attention_heads_per_partition
=
core_utils
.
divide
(
args
.
num_attention_heads
,
world_size
)
# Strided linear layer.
if
attention_type
==
AttnType
.
self_attn
:
self
.
query_key_value
=
tensor_parallel
.
ColumnParallelLinear
(
args
.
hidden_size
,
3
*
projection_size
,
gather_output
=
False
,
init_method
=
init_method
)
else
:
assert
attention_type
==
AttnType
.
cross_attn
self
.
query
=
tensor_parallel
.
ColumnParallelLinear
(
args
.
hidden_size
,
projection_size
,
gather_output
=
False
,
init_method
=
init_method
)
self
.
key_value
=
tensor_parallel
.
ColumnParallelLinear
(
args
.
hidden_size
,
2
*
projection_size
,
gather_output
=
False
,
init_method
=
init_method
)
coeff
=
None
self
.
norm_factor
=
math
.
sqrt
(
self
.
hidden_size_per_attention_head
)
if
self
.
apply_query_key_layer_scaling
:
coeff
=
self
.
layer_number
self
.
norm_factor
*=
coeff
self
.
scale_mask_softmax
=
FusedScaleMaskSoftmax
(
self
.
fp16
,
self
.
bf16
,
self
.
attn_mask_type
,
args
.
masked_softmax_fusion
,
attention_mask_func
,
self
.
attention_softmax_in_fp32
,
coeff
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self
.
attention_dropout
=
torch
.
nn
.
Dropout
(
args
.
attention_dropout
)
# Output.
self
.
dense
=
tensor_parallel
.
RowParallelLinear
(
projection_size
,
args
.
hidden_size
,
input_is_parallel
=
True
,
init_method
=
output_layer_init_method
,
skip_bias_add
=
True
)
def
_allocate_memory
(
self
,
inference_max_sequence_len
,
batch_size
):
return
torch
.
empty
(
inference_max_sequence_len
,
batch_size
,
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
,
dtype
=
self
.
params_dtype
,
device
=
torch
.
cuda
.
current_device
())
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
inference_params
=
None
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if
inference_params
:
if
self
.
layer_number
not
in
inference_params
.
key_value_memory_dict
:
inf_max_seq_len
=
inference_params
.
max_sequence_len
inf_max_batch_size
=
inference_params
.
max_batch_size
inference_key_memory
=
self
.
_allocate_memory
(
inf_max_seq_len
,
inf_max_batch_size
)
inference_value_memory
=
self
.
_allocate_memory
(
inf_max_seq_len
,
inf_max_batch_size
)
inference_params
.
key_value_memory_dict
[
self
.
layer_number
]
=
(
inference_key_memory
,
inference_value_memory
)
else
:
inference_key_memory
,
inference_value_memory
=
\
inference_params
.
key_value_memory_dict
[
self
.
layer_number
]
# =====================
# Query, Key, and Value
# =====================
if
self
.
attention_type
==
AttnType
.
self_attn
:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer
,
_
=
self
.
query_key_value
(
hidden_states
)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape
=
mixed_x_layer
.
size
()[:
-
1
]
+
\
(
self
.
num_attention_heads_per_partition
,
3
*
self
.
hidden_size_per_attention_head
)
mixed_x_layer
=
mixed_x_layer
.
view
(
*
new_tensor_shape
)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(
query_layer
,
key_layer
,
value_layer
)
=
tensor_parallel
\
.
split_tensor_along_last_dim
(
mixed_x_layer
,
3
)
else
:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer
,
_
=
self
.
key_value
(
encoder_output
)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape
=
mixed_kv_layer
.
size
()[:
-
1
]
+
\
(
self
.
num_attention_heads_per_partition
,
2
*
self
.
hidden_size_per_attention_head
)
mixed_kv_layer
=
mixed_kv_layer
.
view
(
*
new_tensor_shape
)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(
key_layer
,
value_layer
)
=
tensor_parallel
\
.
split_tensor_along_last_dim
(
mixed_kv_layer
,
2
)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer
,
_
=
self
.
query
(
hidden_states
)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape
=
query_layer
.
size
()[:
-
1
]
+
\
(
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
)
query_layer
=
query_layer
.
view
(
*
new_tensor_shape
)
# ==================================
# Adjust key and value for inference
# ==================================
if
inference_params
:
batch_start
=
inference_params
.
batch_size_offset
batch_end
=
batch_start
+
key_layer
.
size
(
1
)
assert
batch_end
<=
inference_key_memory
.
size
(
1
)
sequence_start
=
inference_params
.
sequence_len_offset
sequence_end
=
sequence_start
+
key_layer
.
size
(
0
)
assert
sequence_end
<=
inference_key_memory
.
size
(
0
)
# Copy key and values.
inference_key_memory
[
sequence_start
:
sequence_end
,
batch_start
:
batch_end
,
...]
=
key_layer
inference_value_memory
[
sequence_start
:
sequence_end
,
batch_start
:
batch_end
,
...]
=
value_layer
key_layer
=
inference_key_memory
[
:
sequence_end
,
batch_start
:
batch_end
,
...]
value_layer
=
inference_value_memory
[
:
sequence_end
,
batch_start
:
batch_end
,
...]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size
=
(
query_layer
.
size
(
1
),
query_layer
.
size
(
2
),
query_layer
.
size
(
0
),
key_layer
.
size
(
0
))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer
=
query_layer
.
view
(
output_size
[
2
],
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer
=
key_layer
.
view
(
output_size
[
3
],
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# preallocting result tensor: [b * np, sq, sk]
matmul_result
=
torch
.
empty
(
output_size
[
0
]
*
output_size
[
1
],
output_size
[
2
],
output_size
[
3
],
dtype
=
query_layer
.
dtype
,
device
=
torch
.
cuda
.
current_device
())
# Raw attention scores. [b * np, sq, sk]
matmul_result
=
torch
.
baddbmm
(
matmul_result
,
query_layer
.
transpose
(
0
,
1
),
# [b * np, sq, hn]
key_layer
.
transpose
(
0
,
1
).
transpose
(
1
,
2
),
# [b * np, hn, sk]
beta
=
0.0
,
alpha
=
(
1.0
/
self
.
norm_factor
))
# change view to [b, np, sq, sk]
attention_scores
=
matmul_result
.
view
(
*
output_size
)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs
=
self
.
scale_mask_softmax
(
attention_scores
,
attention_mask
)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with
tensor_parallel
.
get_cuda_rng_tracker
().
fork
():
attention_probs
=
self
.
attention_dropout
(
attention_probs
)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size
=
(
value_layer
.
size
(
1
),
value_layer
.
size
(
2
),
query_layer
.
size
(
0
),
value_layer
.
size
(
3
))
# change view [sk, b * np, hn]
value_layer
=
value_layer
.
view
(
value_layer
.
size
(
0
),
output_size
[
0
]
*
output_size
[
1
],
-
1
)
# change view [b * np, sq, sk]
attention_probs
=
attention_probs
.
view
(
output_size
[
0
]
*
output_size
[
1
],
output_size
[
2
],
-
1
)
# matmul: [b * np, sq, hn]
context_layer
=
torch
.
bmm
(
attention_probs
,
value_layer
.
transpose
(
0
,
1
))
# change view [b, np, sq, hn]
context_layer
=
context_layer
.
view
(
*
output_size
)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer
=
context_layer
.
permute
(
2
,
0
,
1
,
3
).
contiguous
()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
\
(
self
.
hidden_size_per_partition
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
# =================
# Output. [sq, b, h]
# =================
output
,
bias
=
self
.
dense
(
context_layer
)
return
output
,
bias
def
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
training
):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out
=
torch
.
nn
.
functional
.
dropout
(
x
+
bias
,
p
=
prob
,
training
=
training
)
out
=
residual
+
out
return
out
def
get_bias_dropout_add
(
training
):
def
_bias_dropout_add
(
x
,
bias
,
residual
,
prob
):
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
training
)
return
_bias_dropout_add
@
torch
.
jit
.
script
def
bias_dropout_add_fused_train
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
True
)
@
torch
.
jit
.
script
def
bias_dropout_add_fused_inference
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
bias_dropout_add
(
x
,
bias
,
residual
,
prob
,
False
)
class ParallelRetroTransformerEncoderLayer(MegatronModule):
    """A single transformer layer for the Retro decoder, with a retriever encoder inside
    and cross attention.

    The transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0., retriever=None):
        args = get_args()

        super(ParallelRetroTransformerEncoderLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Retro Encoder
        self.retriever = retriever

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output, attention_mask,
                                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention. # [ns, bs, d]
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        """
        notations:
            l:  number of chunks
            m:  number of tokens per chunk
            bs: batch size
            d:  hidden size
            k:  number of neighbors
            r:  number of tokens per neighbor (neighbors + continuation)
        """
        args = get_args()
        retro_args = get_retro_args()
        chunk_length = retro_args.retro_gpt_chunk_length
        retrieved_length = retro_args.retro_gpt_retrieved_length
        num_neighbors = args.retro_num_neighbors

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / chunk_length))
        first_ns = ns % chunk_length
        if first_ns > 0:
            first_chunk, rest_chunk = \
                layernorm_output[:first_ns], layernorm_output[first_ns:]
            first_chunk = torch.nn.functional.pad(
                first_chunk,
                (0, 0, 0, 0, 0, chunk_length - first_ns),
                'constant', 0)
            chunked_output = \
                torch.cat((first_chunk, rest_chunk), dim=0)  # [l * m, bs, d]
        else:
            chunked_output = layernorm_output  # [l * m, bs, d]
        chunked_output = chunked_output \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3) \
            .reshape(chunk_length, bs * l, d) \
            .contiguous()

        # Get Encoder Output
        retriever_output = self.retriever(
            retriever_output,
            retriever_attn_mask,
            retriever_output=chunked_output,
            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)  # [r, k * bs * l, d]
        retriever_output = retriever_output.reshape(
            retrieved_length * num_neighbors, bs * l, d)  # [r * k, bs * l, d]

        # Chunked Cross attention with Retriever Encoder
        pad = (ns - 1) % chunk_length
        attending_chunks = layernorm_output[pad:]  # [ns - m + 1, bs, d]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, chunk_length - 1),
            'constant', 0)  # [ns, bs, d]
        padded_chunked_output = padded_chunks \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            chunk_length, bs * l, d).contiguous()  # [m, bs * l, d]

        # attention_output: [m, bs * l, d]
        # attention_bias:   [d]
        attention_output, attention_bias = \
            self.inter_attention(
                padded_chunked_output,          # Q: main model embedding
                None,
                encoder_output=retriever_output)  # KV: retriever output embedding

        # Residual connection
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)

        layernorm_input = layernorm_input \
            .reshape(chunk_length, bs, l, d) \
            .permute(2, 0, 1, 3)  # [l, m, bs, d]
        layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
        layernorm_input = torch.nn.functional.pad(
            layernorm_input,
            (0, 0, 0, 0, pad, 0),
            'constant', 0)[:ns]  # [ns, b, d]
        layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output, retriever_output
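The "chunked cross attention" above folds a sequence of length ns into l chunks of chunk_length m, shifting by pad = (ns - 1) % m so each position only attends to already-retrieved chunks, and moves the chunk index into the batch dimension. The following standalone sketch is not from the original file; it replays just that padding/reshaping step with made-up sizes so the resulting layout can be verified.

import torch

ns, bs, d, m = 10, 2, 8, 4            # sequence length, batch, hidden, chunk length (toy values)
l = -(-ns // m)                        # number of chunks = ceil(ns / m) = 3
x = torch.randn(ns, bs, d)             # stands in for layernorm_output: [ns, bs, d]

pad = (ns - 1) % m                     # offset so chunks align with retrieved neighbors
attending = x[pad:]                    # [ns - m + 1, bs, d]
padded = torch.nn.functional.pad(attending, (0, 0, 0, 0, 0, m - 1))   # back to [l * m, bs, d]
folded = padded.reshape(l, m, bs, d).permute(1, 2, 0, 3).reshape(m, bs * l, d)
assert folded.shape == (m, bs * l, d)  # chunks are folded into the batch dimension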
class ParallelRetroTransformerLayer(MegatronModule):
    """A single transformer layer for the Retro decoder, with cross attention.

    The transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelRetroTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output, attention_mask,
                                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        args = get_args()
        retro_args = get_retro_args()
        chunk_length = retro_args.retro_gpt_chunk_length

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / chunk_length))
        pad = (ns - 1) % chunk_length
        attending_chunks = layernorm_output[pad:]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)

        layernorm_input = layernorm_input \
            .reshape(chunk_length, bs, l, d) \
            .permute(2, 0, 1, 3)  # [l, m, bs, d]
        layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
        layernorm_input = torch.nn.functional.pad(
            layernorm_input,
            (0, 0, 0, 0, pad, 0),
            'constant', 0)[:ns]
        layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output
class ParallelRetroEncoderTransformerCALayer(MegatronModule):
    """A single transformer layer for the Retro encoder, with cross attention.

    The transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelRetroEncoderTransformerCALayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.self_attention.attention_dropout = \
            torch.nn.Dropout(args.retro_encoder_attention_dropout)
        self.hidden_dropout = args.retro_encoder_hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output, attention_mask,
                                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Neighbors.
        args = get_args()
        retro_args = get_retro_args()
        retrieved_length = retro_args.retro_gpt_retrieved_length
        num_neighbors = args.retro_num_neighbors

        ns, bs, d = layernorm_output.shape  # [r, bs * l * k, d]
        chunked_outputs = layernorm_output.reshape(
            retrieved_length, -1, num_neighbors, d)
        chunked_outputs_before_layer_norm = \
            layernorm_input.reshape(retrieved_length, -1,
                                    num_neighbors, d)  # [r, bs * l, k, d]

        layernorm_inputs = []
        layernorm_outputs = []
        for k in range(num_neighbors):
            chunked_output = chunked_outputs[:, :, k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output,                   # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output)  # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_layer_norm[:, :, k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
                layernorm_inputs.append(layernorm_input)

            # Layer norm post the decoder attention
            layernorm_output = \
                self.post_inter_attention_layernorm(layernorm_input)
            layernorm_outputs.append(layernorm_output)

        # layernorm_input  : [r, k * bs * l, d]
        # layernorm_output : [r, k * bs * l, d]
        layernorm_input = \
            torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
        layernorm_output = \
            torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output
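In the encoder cross-attention layer above, the k retrieved neighbors travel flattened inside the batch dimension and are only split out for the per-neighbor loop. The following shape-only sketch is an illustration with assumed toy sizes, not code from the original file; it shows the split and restack without any attention math.

import torch

r, bs_l, k, d = 6, 4, 2, 8                  # retrieved length, bs * l, neighbors, hidden (toy)
hidden = torch.randn(r, bs_l * k, d)        # stands in for layernorm_output: [r, bs * l * k, d]

per_neighbor = hidden.reshape(r, bs_l, k, d)                       # expose the neighbor axis
slices = [per_neighbor[:, :, n].contiguous() for n in range(k)]    # k tensors of [r, bs * l, d]
restacked = torch.stack(slices, dim=1)                             # [r, k, bs * l, d]
assert restacked.reshape(r, bs_l * k, d).shape == hidden.shape     # flat layout is recovered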
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size, eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method, output_layer_init_method, layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size, eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output, attention_mask,
                                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output
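When drop_path_rate > 0, the layers above route the residual branch through DropPath (stochastic depth) instead of the fused bias-dropout-add path. The sketch below is an illustration only: TinyDropPath is a simplified stand-in for the DropPath module defined elsewhere in this file, and the shapes are assumed toy values.

import torch

class TinyDropPath(torch.nn.Module):
    """Minimal stochastic depth: randomly zero the whole residual branch per sample."""
    def __init__(self, p):
        super().__init__()
        self.p = p
    def forward(self, x):                      # x: [s, b, h]
        if not self.training or self.p == 0.:
            return x
        keep_prob = 1.0 - self.p
        # One Bernoulli draw per batch element, broadcast over sequence and hidden dims.
        mask = torch.rand(1, x.shape[1], 1, device=x.device).add_(keep_prob).floor_()
        return x.div(keep_prob) * mask         # rescale so the expectation is unchanged

drop_path = TinyDropPath(0.1)
residual = torch.randn(4, 2, 8)
branch = torch.randn(4, 2, 8)
out = residual + drop_path(branch)             # mirrors 'residual + self.drop_path(out)' above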
class ParallelRetroEncoder(MegatronModule):
    """Retro Transformer class for the encoder."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelRetroEncoder, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flag.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = args.retro_encoder_layers
        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, args.num_layers)]

        if args.retro_add_retriever:
            self.P = [1]

        # Transformer layers.
        assert args.retro_add_retriever
        def build_layer(layer_number):
            if layer_number in self.P:
                return ParallelRetroEncoderTransformerCALayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                layer = ParallelTransformerLayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
                layer.self_attention.attention_dropout = \
                    torch.nn.Dropout(args.retro_encoder_attention_dropout)
                layer.hidden_dropout = args.retro_encoder_hidden_dropout
                return layer

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in
            # the stage, divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an
            # assignment of layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an
            # assignment of layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    parallel_state.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size, eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing the number of checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = parallel_state.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask,
                    encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory by removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = parallel_state.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states, requires_grad=True, keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        args = get_args()
        assert not args.sequence_parallel, "if SP, need rng context."

        # Forward pass.
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                if index + 1 in self.P:
                    hidden_states = layer(
                        hidden_states, attention_mask,
                        retriever_output=retriever_output,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                else:
                    hidden_states = layer(
                        hidden_states, attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output
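The stage-assignment comments in the constructor above describe how the first layer number owned by a rank is derived from the pipeline rank and, when interleaving is enabled, the virtual pipeline rank. The snippet below is a plain-Python illustration of that arithmetic with assumed example values; it is not part of the original file.

def layer_offset(num_layers_total, pp_world_size, vpp_size, pp_rank, vpp_rank):
    """Offset of the first layer owned by (pp_rank, vpp_rank); mirrors the logic above."""
    if vpp_size is not None:
        layers_per_chunk = num_layers_total // pp_world_size // vpp_size
        return vpp_rank * (num_layers_total // vpp_size) + pp_rank * layers_per_chunk
    return pp_rank * (num_layers_total // pp_world_size)

# 8 layers, 2 pipeline stages, 4 model chunks per stage (interleaved schedule):
# stage 0 owns chunks starting at layers 0, 2, 4, 6; stage 1 at 1, 3, 5, 7.
assert [layer_offset(8, 2, 4, 0, v) for v in range(4)] == [0, 2, 4, 6]
assert [layer_offset(8, 2, 4, 1, v) for v in range(4)] == [1, 3, 5, 7]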
class ParallelRetroTransformer(MegatronModule):
    """Standard GPT Transformer class, extended with Retro retrieval."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0, retriever=None):
        super(ParallelRetroTransformer, self).__init__()
        args = get_args()

        assert pre_process and post_process, \
            "pipeline parallelism un-supported."

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flag.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = _get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)
        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, args.num_layers)]

        if args.retro_add_retriever:
            if args.num_layers == 12:
                self.P = [6, 9, 12]
            elif args.num_layers == 24:
                self.P = np.arange(9, 25, 3).tolist()
            elif args.num_layers == 40:
                self.P = np.arange(9, 41, 3).tolist()
                self.P.append(40)
            self.retriever = retriever

        # Transformer layers.
        assert args.retro_add_retriever
        def build_layer(layer_number):
            if layer_number == min(self.P):
                return ParallelRetroTransformerEncoderLayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    retriever=retriever)
            elif layer_number in self.P:
                return ParallelRetroTransformerLayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                return ParallelTransformerLayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    parallel_state.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size, eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing the number of checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = parallel_state.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask,
                    encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory by removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = parallel_state.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                retriever_output=None, retriever_attn_mask=None,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        args = get_args()

        # Transpose retriever output, to match hidden_states shape.
        retriever_output = retriever_output.transpose(0, 1).contiguous()

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states, requires_grad=True, keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        assert not args.sequence_parallel, "if SP, need rng context."
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                if args.retro_add_retriever and index + 1 == min(self.P):
                    hidden_states, E = layer(
                        hidden_states, attention_mask,
                        retriever_output=retriever_output,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                elif args.retro_add_retriever and index + 1 in self.P:
                    hidden_states = layer(
                        hidden_states, attention_mask,
                        retriever_output=E,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                else:
                    hidden_states = layer(
                        hidden_states, attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        output = self.final_layernorm(hidden_states)

        return output
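ParallelRetroTransformer only inserts retrieval cross-attention at the layer numbers collected in self.P, and the first of those layers additionally runs the neighbor encoder. The helper below is an illustration that simply restates the hard-coded schedule from the constructor above so it can be inspected on its own; it is not part of the original file.

import numpy as np

def retro_cross_attention_layers(num_layers):
    """Layer numbers (1-based) that get retrieval cross-attention, per the schedule above."""
    if num_layers == 12:
        return [6, 9, 12]
    if num_layers == 24:
        return np.arange(9, 25, 3).tolist()
    if num_layers == 40:
        p = np.arange(9, 41, 3).tolist()
        p.append(40)
        return p
    raise ValueError("depth not covered by the hard-coded schedule above")

assert retro_cross_attention_layers(24) == [9, 12, 15, 18, 21, 24]
assert min(retro_cross_attention_layers(12)) == 6   # this layer also hosts the retriever encoder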
megatron/model/t5_model.py  View file @ 3aca1415
...
@@ -11,9 +11,7 @@ from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
-    get_linear_layer,
-    init_method_normal,
-    scaled_init_method_normal
+    get_linear_layer
 )
 from .module import MegatronModule
...
@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule):
     Arguments:
         mpu_vocab_size: model parallel size of vocabulary.
-        hidden_size: hidden size
-        init_method: init method for weight initialization
-        layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: wether output logits being distributed or not.
     """

     def __init__(self, mpu_vocab_size, parallel_output):
         super(T5LMHead, self).__init__()

-        args = get_args()
-
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0
...
@@ -72,41 +65,38 @@ class T5Model(MegatronModule):
     """T5 Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=0,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True,
                  add_encoder=True,
                  add_decoder=True):
-        super(T5Model, self).__init__()
+        super().__init__(config=config)
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.parallel_output = parallel_output
-        init_method = init_method_normal(args.init_method_std)
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
         self.pre_process = pre_process
         self.post_process = post_process
         self.add_encoder = add_encoder
         self.add_decoder = add_decoder

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             add_encoder=add_encoder,
             add_decoder=add_decoder,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process)

-        self.initialize_word_embeddings(init_method_normal)
+        self.initialize_word_embeddings()

         if self.post_process and self.add_decoder:
             self.lm_head = T5LMHead(
-                self.word_embeddings_weight().size(0),
+                self.shared_embedding_or_output_weight().size(0),
                 parallel_output)
             self._lm_head_key = 'lm_head'
...
@@ -139,7 +129,7 @@ class T5Model(MegatronModule):
         decoder_output, encoder_output = lm_output

         # Output. [s, b, h]
         lm_logits = self.lm_head(decoder_output,
-                                 self.word_embeddings_weight())
+                                 self.shared_embedding_or_output_weight())

         if lm_labels is None:
             # [s b h] => [b s h]
...
megatron/model/transformer.py  View file @ 3aca1415
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """Transformer."""
-import math
 from contextlib import nullcontext
+import math
+import numpy as np
 import torch
 import torch.nn.functional as F
+from typing import Optional

-from megatron import get_timers, get_args, core, get_num_microbatches
+from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
...
@@ -15,7 +16,7 @@ from megatron.model import LayerNorm
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
+from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 try:
...
@@ -26,7 +27,10 @@ except ImportError:

 try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
-    flash_attn_unpadded_func = None
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
+    except ImportError:
+        flash_attn_unpadded_func = None

 """ We use the following notation throughout this file:
      h: hidden size
...
@@ -65,18 +69,6 @@ class DropPath(MegatronModule):
         output = hidden_state.div(keep_prob) * random_tensor
         return output

-def _args_to_kwargs():
-    args = get_args()
-
-    common_kwargs = {
-        "params_dtype": args.params_dtype,
-        "use_cpu_initialization": args.use_cpu_initialization,
-        "perform_initialization": args.perform_initialization,
-        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
-        "sequence_parallel_enabled": args.sequence_parallel,
-    }
-    return common_kwargs
-
 class ParallelMLP(MegatronModule):
     """MLP.
...
@@ -85,22 +77,26 @@ class ParallelMLP(MegatronModule):
     state back into h hidden dimension.
     """

-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(self, config):
         super(ParallelMLP, self).__init__()
         args = get_args()

-        self.add_bias = args.add_bias_linear
+        self.add_bias = config.add_bias_linear
+
+        ffn_hidden_size = config.ffn_hidden_size
+        if config.gated_linear_unit:
+            ffn_hidden_size *= 2

         # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
-            args.hidden_size,
-            args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
+            config.hidden_size,
+            ffn_hidden_size,
+            config=config,
+            init_method=config.init_method,
             bias=self.add_bias,
             gather_output=False,
-            init_method=init_method,
             skip_bias_add=True,
-            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-            **_args_to_kwargs())
+        )

         self.bias_gelu_fusion = False
         self.activation_func = None
...
@@ -125,13 +121,13 @@ class ParallelMLP(MegatronModule):

         # Project back to h.
         self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
-            args.ffn_hidden_size,
-            args.hidden_size,
+            config.ffn_hidden_size,
+            config.hidden_size,
+            config=config,
+            init_method=config.output_layer_init_method,
             bias=self.add_bias,
-            input_is_parallel=True,
-            init_method=output_layer_init_method,
             skip_bias_add=True,
-            **_args_to_kwargs())
+            input_is_parallel=True)

     def forward(self, hidden_states):
...
@@ -155,13 +151,13 @@ class SwitchMLP(MegatronModule):
     """
     Routes input to one of N MLP "experts"
     """
-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(self, config):
         super(SwitchMLP, self).__init__()
         args = get_args()
-        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
+        self.router = torch.nn.Linear(config.hidden_size, args.num_experts)
         self.experts = torch.nn.ModuleList()
         for i in range(args.num_experts):
-            self.experts.append(ParallelMLP(init_method, output_layer_init_method))
+            self.experts.append(ParallelMLP(config))

     def forward(self, hidden_states):
         # hidden_states: [s, b, h]
...
@@ -188,45 +184,48 @@ class SwitchMLP(MegatronModule):
             local_indices = (max_ind == expert_num).nonzero()
             hidden = hidden_states[local_indices, :]
             output, output_bias = expert(hidden)
-            output_bias = output_bias.expand_as(output)
+            if output_bias is not None:
+                output_bias = output_bias.expand_as(output)
+                output_bias_total[local_indices, :] = output_bias

             output_total[local_indices, :] = output
-            output_bias_total[local_indices, :] = output_bias

         output_total = output_total * max_prob
-        output_bias_total = output_bias_total * max_prob
         output_total = output_total.view(s, b, h)
-        output_bias_total = output_bias_total.view(s, b, h)
+        if output_bias is not None:
+            output_bias_total = output_bias_total * max_prob
+            output_bias_total = output_bias_total.view(s, b, h)
+        else:
+            output_bias_total = None

         return output_total, output_bias_total


 class CoreAttention(MegatronModule):

-    def __init__(self, layer_number,
+    def __init__(self, layer_number, config,
                  attn_mask_type=AttnMaskType.padding):
         super(CoreAttention, self).__init__()
-        args = get_args()
-        self.fp16 = args.fp16
-        self.bf16 = args.bf16
+        self.fp16 = config.fp16
+        self.bf16 = config.bf16

-        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
+        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         self.attn_mask_type = attn_mask_type
-        self.sequence_parallel = args.sequence_parallel
+        self.sequence_parallel = config.sequence_parallel

-        projection_size = args.kv_channels * args.num_attention_heads
+        projection_size = config.kv_channels * config.num_attention_heads

         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                            world_size)
         self.hidden_size_per_attention_head = core.utils.divide(
-            projection_size, args.num_attention_heads)
+            projection_size, config.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(
-            args.num_attention_heads, world_size)
+            config.num_attention_heads, world_size)

         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
...
@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16, self.bf16,
             self.attn_mask_type,
-            args.masked_softmax_fusion,
+            config.masked_softmax_fusion,
             attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
...
@@ -245,7 +244,7 @@ class CoreAttention(MegatronModule):
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

     def forward(self, query_layer, key_layer,
                 value_layer, attention_mask):
...
@@ -261,8 +260,8 @@ class CoreAttention(MegatronModule):
                        key_layer.size(0))

         # [sq, b, np, hn] -> [sq, b * np, hn]
-        query_layer = query_layer.view(output_size[2],
-                                       output_size[0] * output_size[1], -1)
+        query_layer = query_layer.reshape(output_size[2],
+                                          output_size[0] * output_size[1], -1)
         # [sk, b, np, hn] -> [sk, b * np, hn]
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)
...
@@ -379,17 +378,18 @@ class FlashSelfAttention(torch.nn.Module):
             is_causal = self.causal
             cu_seqlens_k = cu_seqlens_q
+            dropout_p = self.dropout_p
         else:
             # turn off FA causal mask after first inference autoregressive iteration
             # only on first autoregressive step q,k,v have same seqlen
             is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                         dtype=torch.int32, device=q.device)
-            self.dropout_p = 0
+            dropout_p = 0

         output = flash_attn_unpadded_func(
             q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            self.dropout_p,
+            dropout_p,
             softmax_scale=self.softmax_scale, causal=is_causal
         )
...
@@ -404,8 +404,7 @@ class ParallelAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, init_method,
-                 output_layer_init_method, layer_number,
+    def __init__(self, config, layer_number,
                  attention_type=AttnType.self_attn,
                  attn_mask_type=AttnMaskType.padding):
         super(ParallelAttention, self).__init__()
...
@@ -413,10 +412,21 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
-        self.params_dtype = args.params_dtype
-        self.sequence_parallel = args.sequence_parallel
+        self.params_dtype = config.params_dtype
+        self.sequence_parallel = config.sequence_parallel
+
+        self.group_query_attention = args.group_query_attention
+        self.num_query_groups = args.num_query_groups
+
+        query_projection_size = config.kv_channels * config.num_attention_heads
+        if self.group_query_attention:
+            kv_projection_size = args.kv_channels * args.num_query_groups
+        else:
+            kv_projection_size = args.kv_channels * args.num_attention_heads

-        self.use_flash_attn = args.use_flash_attn
+        self.use_flash_attn = args.use_flash_attn \
+            and attention_type == AttnType.self_attn \
+            and self.attn_mask_type == AttnMaskType.causal
         if self.use_flash_attn:
             if flash_attn_unpadded_func is None:
                 raise ImportError('FlashAttention is not installed, please install with '
...
@@ -428,64 +438,72 @@ class ParallelAttention(MegatronModule):
             if rearrange is None:
                 raise ImportError('einops is not installed, please install with pip install einops')

-        projection_size = args.kv_channels * args.num_attention_heads
-
         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = core.utils.divide(
-            projection_size, args.num_attention_heads)
+            query_projection_size, config.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(
-            args.num_attention_heads, world_size)
+            config.num_attention_heads, world_size)
+
+        if self.group_query_attention:
+            if args.num_query_groups % world_size != 0:
+                raise NotImplementedError('Currently the num_query_groups should be '
+                                          'a multiple of the tensor parallel size')
+            self.num_query_groups_per_partition = core.utils.divide(
+                        args.num_query_groups, world_size)
+        else:
+            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

         # Strided linear layer.
         if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size,
-                3 * projection_size,
+                config.hidden_size,
+                query_projection_size + 2 * kv_projection_size,
+                config=config,
+                init_method=config.init_method,
                 bias=args.add_bias_linear,
-                gather_output=False,
-                init_method=init_method,
-                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-                **_args_to_kwargs())
+                gather_output=False)
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size,
-                projection_size,
-                bias=args.add_bias_linear,
-                gather_output=False,
-                init_method=init_method,
-                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-                **_args_to_kwargs())
-
-            self.key_value = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size,
-                2 * projection_size,
-                bias=args.add_bias_linear,
-                gather_output=False,
-                init_method=init_method,
-                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-                **_args_to_kwargs())
+
+            if self.group_query_attention:
+                raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
+            assert query_projection_size == kv_projection_size
+
+            self.query = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                query_projection_size,
+                config=config,
+                init_method=config.init_method,
+                bias=config.add_bias_linear,
+                gather_output=False)

-        self.core_attention = CoreAttention(self.layer_number,
-                                            self.attn_mask_type)
-        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+            self.key_value = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                2 * kv_projection_size,
+                config=config,
+                init_method=config.init_method,
+                bias=config.add_bias_linear,
+                gather_output=False)
+
+        self.core_attention = CoreAttention(self.layer_number, config,
                                             self.attn_mask_type)
+        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

         if self.use_flash_attn:
             self.core_attention_flash = FlashSelfAttention(
-                causal=True, attention_dropout=args.attention_dropout
+                causal=True, attention_dropout=config.attention_dropout
             )

         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
-            projection_size,
-            args.hidden_size,
+            query_projection_size,
+            config.hidden_size,
+            config=config,
+            init_method=config.output_layer_init_method,
             bias=args.add_bias_linear,
             input_is_parallel=True,
-            init_method=output_layer_init_method,
-            skip_bias_add=True,
-            **_args_to_kwargs())
+            skip_bias_add=True)

     def _checkpointed_attention_forward(self, query_layer, key_layer,
                                         value_layer, attention_mask,
...
@@ -510,11 +528,11 @@ class ParallelAttention(MegatronModule):

        return hidden_states

-   def _allocate_memory(self, inference_max_sequence_len, batch_size):
+   def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
-           self.num_attention_heads_per_partition,
+           num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())
...
...
@@ -530,12 +548,15 @@ class ParallelAttention(MegatronModule):
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
-               inf_max_seq_len = inference_params.max_sequence_len
+               inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
-                   inf_max_seq_len, inf_max_batch_size)
+                   inf_max_seq_len, inf_max_batch_size,
+                   self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
-                   inf_max_seq_len, inf_max_batch_size)
+                   inf_max_seq_len, inf_max_batch_size,
+                   self.num_query_groups_per_partition)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
...
...
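Aside: with grouped-query attention the inference key/value cache above is sized by the per-partition number of query groups rather than the full query-head count. A standalone sketch of such a pre-allocated cache; dtype, device, and the shapes are illustrative defaults, not the layer's actual params_dtype or CUDA device:

# --- illustrative sketch, not part of the diff ---
import torch

def allocate_kv_cache(max_seq_len, batch_size, num_kv_heads, head_dim,
                      dtype=torch.float32, device="cpu"):
    # One buffer each for keys and values, laid out [s, b, num_kv_heads, hn]
    # to match the sequence-first activation layout used in the layer.
    key = torch.empty(max_seq_len, batch_size, num_kv_heads, head_dim,
                      dtype=dtype, device=device)
    value = torch.empty_like(key)
    return key, value

k_cache, v_cache = allocate_kv_cache(max_seq_len=16, batch_size=2,
                                     num_kv_heads=4, head_dim=8)
print(k_cache.shape)  # torch.Size([16, 2, 4, 8])
# --- end sketch ---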
@@ -546,21 +567,36 @@ class ParallelAttention(MegatronModule):
        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:
-           # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+           # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

-           # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
-           new_tensor_shape = mixed_x_layer.size()[:-1] + \
-               (self.num_attention_heads_per_partition,
-                3 * self.hidden_size_per_attention_head)
+           # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+           new_tensor_shape = mixed_x_layer.size()[:-1] + (
+               self.num_query_groups_per_partition,
+               (
+                   (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
+                   * self.hidden_size_per_attention_head
+               ),
+           )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-           # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-           (query_layer,
-            key_layer,
-            value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+           # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+           (query_layer,
+            key_layer,
+            value_layer) = torch.split(
+               mixed_x_layer,
+               [
+                   (self.num_attention_heads_per_partition // self.num_query_groups_per_partition
+                    * self.hidden_size_per_attention_head),
+                   self.hidden_size_per_attention_head,
+                   self.hidden_size_per_attention_head
+               ],
+               dim=3)
+
+           # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+           query_layer = query_layer.view(query_layer.size(0),
+                                          query_layer.size(1), -1,
+                                          self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)
...
...
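Aside: the hunk above reshapes the fused projection to [sq, b, ng, (np/ng + 2) * hn] and then splits each group into np/ng query sub-heads plus one key and one value head. A minimal sketch of that split on random data; all the dimensions below are invented for illustration:

# --- illustrative sketch, not part of the diff ---
import torch

sq, b = 4, 2            # sequence length, batch
np_, ng, hn = 8, 2, 16  # query heads, query groups, per-head dim

mixed = torch.randn(sq, b, ng, (np_ // ng + 2) * hn)

# Per group: np/ng*hn query channels, hn key channels, hn value channels.
query, key, value = torch.split(
    mixed, [np_ // ng * hn, hn, hn], dim=3)

# Queries regain one axis per head: [sq, b, np, hn].
query = query.reshape(sq, b, -1, hn)
print(query.shape, key.shape, value.shape)
# torch.Size([4, 2, 8, 16]) torch.Size([4, 2, 2, 16]) torch.Size([4, 2, 2, 16])
# --- end sketch ---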
@@ -568,19 +604,19 @@ class ParallelAttention(MegatronModule):

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
...
...
@@ -632,11 +668,20 @@ class ParallelAttention(MegatronModule):
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

+       # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
+       key_layer = key_layer.repeat_interleave(
+           self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+           dim=2)
+       value_layer = value_layer.repeat_interleave(
+           self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+           dim=2)
+
        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
...
...
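Aside: before core attention, each key/value group is broadcast back to the full query-head count with repeat_interleave. A toy sketch of that expansion (sizes are made up):

# --- illustrative sketch, not part of the diff ---
import torch

sk, b, ng, hn = 4, 2, 2, 16
np_ = 8                      # query heads

key = torch.randn(sk, b, ng, hn)
# Each KV group is repeated np/ng times along the head axis:
# [sk, b, ng, hn] -> [sk, b, np, hn]
key_expanded = key.repeat_interleave(np_ // ng, dim=2)
print(key_expanded.shape)  # torch.Size([4, 2, 8, 16])
# --- end sketch ---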
@@ -711,10 +756,11 @@ class ParallelTransformerLayer(MegatronModule):
    output of the same size.
    """

-   def __init__(self, init_method, output_layer_init_method,
+   def __init__(self, config,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
+                # retriever=None):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
...
...
@@ -722,57 +768,59 @@ class ParallelTransformerLayer(MegatronModule):
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm \
-           = args.apply_residual_connection_post_layernorm
+           = config.apply_residual_connection_post_layernorm

-       self.bf16 = args.bf16
-       self.fp32_residual_connection = args.fp32_residual_connection
+       self.bf16 = config.bf16
+       self.fp32_residual_connection = config.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
-           args.hidden_size,
-           eps=args.layernorm_epsilon,
+           config.hidden_size,
+           eps=config.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
-           sequence_parallel=args.sequence_parallel,
+           sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        # Self attention.
        self.self_attention = ParallelAttention(
-           init_method,
-           output_layer_init_method,
+           config,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
-       self.hidden_dropout = args.hidden_dropout
-       self.bias_dropout_fusion = args.bias_dropout_fusion
+       self.hidden_dropout = config.hidden_dropout
+       self.bias_dropout_fusion = config.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
-           args.hidden_size,
-           eps=args.layernorm_epsilon,
-           no_persist_layer_norm=args.no_persist_layer_norm,
-           sequence_parallel=args.sequence_parallel,
+           config.hidden_size,
+           eps=config.layernorm_epsilon,
+           no_persist_layer_norm=not config.persist_layer_norm,
+           sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

-       if self.layer_type == LayerType.decoder:
+       # Cross attention.
+       if self.layer_type in (LayerType.decoder,
+                              LayerType.retro_decoder,
+                              LayerType.retro_decoder_with_retriever,
+                              LayerType.retro_encoder):
            self.inter_attention = ParallelAttention(
-               init_method,
-               output_layer_init_method,
+               config,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
-               args.hidden_size,
-               eps=args.layernorm_epsilon,
-               no_persist_layer_norm=args.no_persist_layer_norm,
-               sequence_parallel=args.sequence_parallel,
+               config.hidden_size,
+               eps=config.layernorm_epsilon,
+               no_persist_layer_norm=not config.persist_layer_norm,
+               sequence_parallel=config.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)

        # MLP
        if args.num_experts is not None:
-           self.mlp = SwitchMLP(init_method, output_layer_init_method)
+           self.mlp = SwitchMLP(config)
        else:
-           self.mlp = ParallelMLP(init_method, output_layer_init_method)
+           self.mlp = ParallelMLP(config)

        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
...
...
@@ -781,13 +829,245 @@ class ParallelTransformerLayer(MegatronModule):
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

        if args.retro_add_retriever:
            retro_args = get_retro_args()
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = retro_args.retro_gpt_chunk_length
            self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length

        # Retriever (bi-directional transformer with cross attention)
        if layer_type == LayerType.retro_decoder_with_retriever:
            self.retriever = ParallelTransformer(
                config=config,
                model_type=ModelType.retro_encoder,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=True,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

    def default_decoder_cross_attention(self,
                                        encoder_output,
                                        enc_dec_attn_mask,
                                        layernorm_input,
                                        layernorm_output,
                                        bias_dropout_add_func):
        '''Cross attention for a standard encoder-decoder model.'''

        # Attention.
        attention_output, attention_bias = \
            self.inter_attention(layernorm_output,
                                 enc_dec_attn_mask,
                                 encoder_output=encoder_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if attention_bias is not None:
            attention_bias = attention_bias.expand_as(residual)

        # Bias-dropout-add.
        with self.bias_dropout_add_exec_handler():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout)

        # Layer norm.
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        return layernorm_input, layernorm_output

    def retro_encoder_cross_attention(self,
                                      retriever_output,
                                      layernorm_input,
                                      layernorm_output,
                                      bias_dropout_add_func):
        """Cross attention for Retro encoder.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).
        """

        ns, bs, d = layernorm_output.shape  # [r, bs * l * k, d]

        # Divide sequence dimension into chunks.
        chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length,
                                                   -1,
                                                   self.retro_num_neighbors,
                                                   d)
        chunked_outputs_before_layer_norm = \
            layernorm_input.reshape(self.retro_retrieved_length, -1,
                                    self.retro_num_neighbors, d)  # [r, bs*l, k, d]

        # Per-chunk attention.
        layernorm_inputs = []
        layernorm_outputs = []
        for k in range(self.retro_num_neighbors):

            # Attention.
            chunked_output = chunked_outputs[:, :, k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output,                   # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output)  # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_layer_norm[:, :, k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    None if attention_bias is None else attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
                layernorm_inputs.append(layernorm_input)

            # Layer norm.
            layernorm_output = \
                self.post_inter_attention_layernorm(layernorm_input)
            layernorm_outputs.append(layernorm_output)

        # Concatenate layer norms.
        # layernorm_input : [r, k * bs * l, d]
        # layernorm_output : [r, k * bs * l, d]
        layernorm_input = \
            torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
        layernorm_output = \
            torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)

        return layernorm_input, layernorm_output

    def retro_decoder_cross_attention(self,
                                      retriever_input,
                                      retriever_output,
                                      retriever_attn_mask,
                                      layernorm_input,
                                      layernorm_output,
                                      inference_params,
                                      bias_dropout_add_func):
        """Cross attention for Retro decoder.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            m  : Number of tokens per chunk.
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).
        """

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / self.retro_chunk_length))

        # Retrieve neighbors.
        if self.layer_type == LayerType.retro_decoder_with_retriever:
            first_ns = ns % self.retro_chunk_length
            if first_ns > 0:
                raise Exception("test this case.")
                first_chunk, rest_chunk = \
                    layernorm_output[:first_ns], layernorm_output[first_ns:]
                first_chunk = torch.nn.functional.pad(
                    first_chunk,
                    (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
                    'constant',
                    0)
                chunked_output = \
                    torch.cat((first_chunk, rest_chunk), dim=0)  # [l * m, bs, d]
            else:
                chunked_output = layernorm_output  # [l * m, bs, d]
            chunked_output = chunked_output \
                .reshape(l, self.retro_chunk_length, bs, d) \
                .permute(1, 2, 0, 3) \
                .reshape(self.retro_chunk_length, bs * l, d) \
                .contiguous()

            # Get Encoder Output
            retriever_output = self.retriever(
                hidden_states=retriever_input,
                attention_mask=retriever_attn_mask,
                retriever_output=chunked_output,
                retriever_attn_mask=retriever_attn_mask,
                inference_params=inference_params)  # [r, k * bs * l, d]
            retriever_output = retriever_output.reshape(
                self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d)  # [r * k, bs * l, d]

        # Chunks.
        pad = (ns - 1) % self.retro_chunk_length
        attending_chunks = layernorm_output[pad:]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, self.retro_chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            self.retro_chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                None if attention_bias is None else attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)
            layernorm_input = layernorm_input \
                .reshape(self.retro_chunk_length, bs, l, d) \
                .permute(2, 0, 1, 3)  # [l, m, bs, d]
            layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d)
            layernorm_input = torch.nn.functional.pad(
                layernorm_input,
                (0, 0, 0, 0, pad, 0),
                'constant', 0)[:ns]  # [ns, b, d]
            layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        return retriever_output, layernorm_input, layernorm_output

    def forward(self, hidden_states, attention_mask,
-               encoder_output=None, enc_dec_attn_mask=None,
-               inference_params=None,
-               rotary_pos_emb=None):
+               encoder_output=None, enc_dec_attn_mask=None,
+               retriever_input=None,
+               retriever_output=None,
+               retriever_attn_mask=None,
+               inference_params=None,
+               rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
...
...
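Aside: the Retro decoder path above reshapes the decoder sequence into l chunks of retro_chunk_length tokens, padding so the reshape is exact, before cross-attending to retrieved neighbors. A simplified, self-contained sketch of just that chunking arithmetic; the numbers are invented and the padding here is a plain tail pad, not the exact first-chunk/attending-chunk scheme used in the real code:

# --- illustrative sketch, not part of the diff ---
import numpy as np
import torch

ns, bs, d = 10, 2, 8       # sequence length, batch size, hidden size
chunk_length = 4
l = int(np.ceil(ns / chunk_length))   # 3 chunks

x = torch.randn(ns, bs, d)
pad = l * chunk_length - ns           # 2 trailing positions of padding
x_padded = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad))

# [l * m, bs, d] -> [m, bs * l, d], so each chunk attends independently.
chunked = (x_padded.reshape(l, chunk_length, bs, d)
                   .permute(1, 2, 0, 3)
                   .reshape(chunk_length, bs * l, d)
                   .contiguous())
print(chunked.shape)  # torch.Size([4, 6, 8])
# --- end sketch ---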
@@ -832,29 +1112,38 @@ class ParallelTransformerLayer(MegatronModule):
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

-       if self.layer_type == LayerType.decoder:
-           attention_output, attention_bias = \
-               self.inter_attention(layernorm_output,
-                                    enc_dec_attn_mask,
-                                    encoder_output=encoder_output)
-           # residual connection
-           if self.apply_residual_connection_post_layernorm:
-               residual = layernorm_output
-           else:
-               residual = layernorm_input
-
-           if attention_bias is not None:
-               attention_bias = attention_bias.expand_as(residual)
-
-           with self.bias_dropout_add_exec_handler():
-               layernorm_input = bias_dropout_add_func(
-                   attention_output,
-                   attention_bias,
-                   residual,
-                   self.hidden_dropout)
-
-           # Layer norm post the decoder attention
-           layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+       # Cross attention.
+       if self.layer_type == LayerType.encoder:
+           pass
+       elif self.layer_type == LayerType.decoder:
+           layernorm_input, layernorm_output = \
+               self.default_decoder_cross_attention(
+                   encoder_output,
+                   enc_dec_attn_mask,
+                   layernorm_input,
+                   layernorm_output,
+                   bias_dropout_add_func)
+       elif self.layer_type == LayerType.retro_encoder:
+           layernorm_input, layernorm_output = \
+               self.retro_encoder_cross_attention(
+                   retriever_output,
+                   layernorm_input,
+                   layernorm_output,
+                   bias_dropout_add_func)
+       elif self.layer_type in (LayerType.retro_decoder,
+                                LayerType.retro_decoder_with_retriever):
+           retriever_output, layernorm_input, layernorm_output = \
+               self.retro_decoder_cross_attention(
+                   retriever_input,
+                   retriever_output,
+                   retriever_attn_mask,
+                   layernorm_input,
+                   layernorm_output,
+                   inference_params,
+                   bias_dropout_add_func)
+       else:
+           raise Exception("Unsupported layer type, '%s'." %
+                           self.layer_type.name)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)
...
...
@@ -893,7 +1182,10 @@ class ParallelTransformerLayer(MegatronModule):
                                     training=self.training)
            output = residual + self.drop_path(out)

-       return output
+       if self.layer_type == LayerType.retro_decoder_with_retriever:
+           return output, retriever_output
+       else:
+           return output


class NoopTransformerLayer(MegatronModule):
...
...
@@ -922,9 +1214,12 @@ class NoopTransformerLayer(MegatronModule):
        return hidden_states.clone()


-def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
+def _get_num_layers(args, model_type, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank."""
-   if mpu.get_pipeline_model_parallel_world_size() > 1:
+   is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
+   if model_type == ModelType.retro_encoder:
+       num_layers = args.retro_encoder_layers
+   elif mpu.get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
...
...
@@ -974,51 +1269,91 @@ def _get_num_layers(args, model_type, is_decoder=False):
    return num_layers


def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                    layer_number):
    args = get_args()
    if args.retro_add_retriever and layer_number in retro_layer_numbers:
        if model_type == ModelType.retro_decoder:
            return LayerType.retro_decoder_with_retriever \
                if layer_number == retro_layer_numbers[0] \
                   else LayerType.retro_decoder
        elif model_type == ModelType.retro_encoder:
            return LayerType.retro_encoder
        else:
            raise Exception("Unsupported model type, '%s'." % model_type)
    else:
        return default_layer_type


class ParallelTransformer(MegatronModule):
    """Transformer class."""

-   def __init__(self, init_method, output_layer_init_method,
-                layer_type=LayerType.encoder,
+   def __init__(self, config,
+                model_type, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True,
                 post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
-       self.model_type = args.model_type
-       self.bf16 = args.bf16
-       self.fp32_residual_connection = args.fp32_residual_connection
+       self.model_type = model_type
+       self.bf16 = config.bf16
+       self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl
+       self.retro_add_retriever = args.retro_add_retriever

        # Store activation checkpoiting flag.
-       self.recompute_granularity = args.recompute_granularity
-       self.recompute_method = args.recompute_method
-       self.recompute_num_layers = args.recompute_num_layers
+       self.recompute_granularity = config.recompute_granularity
+       self.recompute_method = config.recompute_method
+       self.recompute_num_layers = config.recompute_num_layers
        self.distribute_saved_activations = \
-           args.distribute_saved_activations and not args.sequence_parallel
+           config.distribute_saved_activations and not config.sequence_parallel

-       self.sequence_parallel = args.sequence_parallel
+       self.sequence_parallel = config.sequence_parallel

        # Transformer Engine Init.
+       self.transformer_engine_v_0_10 = False
+       self.transformer_engine_v_0_11 = False
+       self.transformer_engine_v_0_8 = False
        if self.transformer_impl == 'transformer_engine':
            global transformer_engine
            import transformer_engine
-           self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
+           from importlib.metadata import version
+           from pkg_resources import packaging
+
+           te_version = packaging.version.Version(version("transformer-engine"))
+           if te_version >= packaging.version.Version("0.8.0"):
+               self.transformer_engine_v_0_8 = True
+           if te_version >= packaging.version.Version("0.10.0"):
+               self.transformer_engine_v_0_10 = True
+           if te_version >= packaging.version.Version("0.11.0"):
+               self.transformer_engine_v_0_11 = True
+           del version, packaging
+
+           assert not args.squared_relu, "TransformerEngine does not support squared relu activation."
+
+       self.use_fp8 = args.fp8 is not None
        self.fp8_recipe = None
        self.fp8_group = None
        if self.use_fp8:
-           assert args.transformer_impl == 'transformer_engine', \
-               'transformer-engine required for fp8 training and inference'
-           self.fp8_group = mpu.get_data_parallel_group()
-           if args.fp8_e4m3:
+           self.fp8_group = mpu.get_amax_reduction_group()
+           if args.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
-           elif args.fp8_hybrid:
+           elif args.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
+           else:
+               raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
...
...
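Aside: the hunk above probes the installed Transformer Engine version before enabling newer keyword arguments. A minimal version-gating sketch in the same spirit; it uses the standalone `packaging` package rather than the diff's `pkg_resources.packaging`, and the "activation" key is only an example of a gated option:

# --- illustrative sketch, not part of the diff ---
from importlib.metadata import version, PackageNotFoundError
from packaging.version import Version

def te_version_at_least(minimum: str) -> bool:
    """True if transformer-engine is installed and at least `minimum`."""
    try:
        return Version(version("transformer-engine")) >= Version(minimum)
    except PackageNotFoundError:
        return False

# Gate optional keyword arguments the way the diff gates newer TE options.
extra_kwargs = {}
if te_version_at_least("0.10.0"):
    extra_kwargs["activation"] = "gelu"
print(extra_kwargs)
# --- end sketch ---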
@@ -1030,63 +1365,87 @@ class ParallelTransformer(MegatronModule):
        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
-       self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+       self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
-       self.num_layers = _get_num_layers(
-           args,
-           args.model_type == ModelType.encoder_and_decoder,
-           layer_type == LayerType.decoder)
+       self.num_layers = _get_num_layers(args, model_type,
+                                         layer_type == LayerType.decoder)

-       self.drop_path_rates = [rate.item() for rate in
-                               torch.linspace(0, self.drop_path_rate, args.num_layers)]
+       self.drop_path_rates = [rate.item() for rate in
+                               torch.linspace(0, self.drop_path_rate, config.num_layers)]

+       self.retro_layer_numbers = None
+       if model_type == ModelType.retro_decoder:
+           retro_layer_start = 6 if config.num_layers <= 15 else 9
+           self.retro_layer_numbers = \
+               np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
+       if model_type == ModelType.retro_encoder:
+           self.retro_layer_numbers = [1]

        # Transformer layers.
+       if args.retro_add_retriever:
+           assert self.recompute_granularity != 'full', \
+               "Full recompute not supported for Retro."
+           assert args.transformer_impl == 'local', \
+               "Transformer engine does not support Retro layers."
        def build_layer(layer_number):
            if args.transformer_impl == 'local':
+               current_layer_type = _get_layer_type(
+                   model_type, layer_type, self.retro_layer_numbers,
+                   layer_number)
                return ParallelTransformerLayer(
-                   init_method,
-                   output_layer_init_method,
+                   config,
                    layer_number,
-                   layer_type=layer_type,
+                   layer_type=current_layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
+               # This argument is only available from TE v0.10 onwards.
+               extra_transformer_engine_kwargs = {}
+               if self.transformer_engine_v_0_8:
+                   extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
+               if self.transformer_engine_v_0_10:
+                   extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
+               if self.transformer_engine_v_0_11:
+                   extra_transformer_engine_kwargs["normalization"] = args.normalization
                return transformer_engine.pytorch.TransformerLayer(
-                   args.hidden_size,
-                   args.ffn_hidden_size,
-                   args.num_attention_heads,
-                   layernorm_epsilon=args.layernorm_epsilon,
-                   hidden_dropout=args.hidden_dropout,
-                   attention_dropout=args.attention_dropout,
-                   init_method=init_method,
-                   output_layer_init_method=output_layer_init_method,
+                   config.hidden_size,
+                   config.ffn_hidden_size,
+                   config.num_attention_heads,
+                   layernorm_epsilon=config.layernorm_epsilon,
+                   hidden_dropout=config.hidden_dropout,
+                   attention_dropout=config.attention_dropout,
+                   init_method=config.init_method,
+                   output_layer_init_method=config.output_layer_init_method,
                    layer_number=layer_number,
-                   kv_channels=args.kv_channels,
+                   kv_channels=config.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group(),
                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
-                   fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
-                   apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-                   attention_softmax_in_fp32=args.attention_softmax_in_fp32,
+                   fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
+                   apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
+                   attention_softmax_in_fp32=config.attention_softmax_in_fp32,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
-                   sequence_parallel=args.sequence_parallel,
-                   params_dtype=args.params_dtype,
-                   apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
+                   sequence_parallel=config.sequence_parallel,
+                   params_dtype=config.params_dtype,
+                   apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
-                   fuse_qkv_params=True)
+                   fuse_qkv_params=True,
+                   **extra_transformer_engine_kwargs)

-       if args.virtual_pipeline_model_parallel_size is not None:
-           assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+       if config.virtual_pipeline_model_parallel_size is not None:
+           assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
-           self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+           self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
...
...
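Aside: per-layer drop-path rates above come from torch.linspace over config.num_layers, so stochastic depth grows linearly with layer index. A short sketch with invented settings:

# --- illustrative sketch, not part of the diff ---
import torch

drop_path_rate = 0.1
num_layers = 8

# One rate per layer, linearly spaced from 0 to drop_path_rate.
drop_path_rates = [rate.item() for rate in
                   torch.linspace(0, drop_path_rate, num_layers)]
print([round(r, 3) for r in drop_path_rates])
# [0.0, 0.014, 0.029, 0.043, 0.057, 0.071, 0.086, 0.1]
# --- end sketch ---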
@@ -1096,7 +1455,7 @@ class ParallelTransformer(MegatronModule):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
-               args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+               config.num_layers // config.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
...
...
@@ -1126,13 +1485,24 @@ class ParallelTransformer(MegatronModule):
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

+           # Update dropout rate for Retro encoder.
+           if model_type == ModelType.retro_encoder:
+               for layer in self.layers:
+                   if layer.self_attention.use_flash_attn:
+                       layer.self_attention.core_attention_flash.dropout_p = \
+                           torch.nn.Dropout(args.retro_encoder_attention_dropout)
+                   else:
+                       layer.self_attention.core_attention.attention_dropout.p = \
+                           args.retro_encoder_attention_dropout
+                   layer.hidden_dropout = args.retro_encoder_hidden_dropout
+
        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
-               args.hidden_size,
-               eps=args.layernorm_epsilon,
+               config.hidden_size,
+               eps=config.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
-               sequence_parallel=args.sequence_parallel,
+               sequence_parallel=config.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)

    def _get_layer(self, layer_number):
...
...
@@ -1142,40 +1512,42 @@ class ParallelTransformer(MegatronModule):
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing."""
-       def custom(start, end, is_transformer_engine=False):
+       def custom(start, end):
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_
-           def custom_forward_transformer_engine(*args, **kwargs):
-               return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
-           if not is_transformer_engine:
-               return custom_forward
-           else:
-               return custom_forward_transformer_engine
+           return custom_forward
+
+       te_forward_kwargs = {}
+       if self.transformer_impl == 'transformer_engine':
+           te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
+           if self.transformer_engine_v_0_10:
+               te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            l = 0
            while l < self.num_layers:
                if self.transformer_impl == 'transformer_engine':
-                   hidden_states = transformer_engine.pytorch.distributed.checkpoint(
-                       custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
+                   hidden_states = transformer_engine.pytorch.checkpoint(
+                       custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        tensor_parallel.get_cuda_rng_tracker,
                        mpu.get_tensor_model_parallel_group(),
-                       hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
-                       rotary_pos_emb)
+                       hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
+                       **te_forward_kwargs)
                else:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
-                       hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
-                       rotary_pos_emb)
+                       hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
+                       None, None, None, None, rotary_pos_emb)

                l += self.recompute_num_layers
...
...
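Aside: the uniform recompute path above walks the layer list in fixed-size chunks and checkpoints only each chunk's input activation. A generic, self-contained sketch of that pattern using torch.utils.checkpoint (the real code routes through tensor_parallel.checkpoint or Transformer Engine's checkpoint, and the layers here are toy Linear modules):

# --- illustrative sketch, not part of the diff ---
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(8))
recompute_num_layers = 2

def run_chunk(start, end):
    def custom_forward(x):
        for index in range(start, end):
            x = layers[index](x)
        return x
    return custom_forward

x = torch.randn(4, 16, requires_grad=True)
l = 0
while l < len(layers):
    # Only the chunk-boundary activation is stored; the rest is recomputed in backward.
    x = checkpoint(run_chunk(l, l + recompute_num_layers), x, use_reentrant=False)
    l += recompute_num_layers
print(x.shape)  # torch.Size([4, 16])
# --- end sketch ---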
@@ -1186,28 +1558,30 @@ class ParallelTransformer(MegatronModule):
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    if self.transformer_impl == 'transformer_engine':
-                       hidden_states = transformer_engine.pytorch.distributed.checkpoint(
-                           custom(l, l + 1, is_transformer_engine=True),
+                       hidden_states = transformer_engine.pytorch.checkpoint(
+                           custom(l, l + 1),
                            self.distribute_saved_activations,
                            tensor_parallel.get_cuda_rng_tracker,
                            mpu.get_tensor_model_parallel_group(),
-                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
-                           rotary_pos_emb)
+                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
+                           **te_forward_kwargs)
                    else:
                        hidden_states = tensor_parallel.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
-                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
-                           rotary_pos_emb)
+                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
+                           None, None, None, None, rotary_pos_emb)
                else:
                    if self.transformer_impl == 'transformer_engine':
-                       hidden_states = custom(l, l + 1, is_transformer_engine=True)(
-                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
-                           rotary_pos_emb)
+                       hidden_states = custom(l, l + 1)(
+                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
+                           **te_forward_kwargs)
                    else:
                        hidden_states = custom(l, l + 1)(
-                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
-                           rotary_pos_emb)
+                           hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
+                           None, None, None, None, rotary_pos_emb)
        else:
            raise ValueError("Invalid activation recompute method.")
...
...
@@ -1225,7 +1599,11 @@ class ParallelTransformer(MegatronModule):
    def forward(self, hidden_states, attention_mask,
-               encoder_output=None, enc_dec_attn_mask=None,
-               inference_params=None,
-               rotary_pos_emb=None):
+               encoder_output=None, enc_dec_attn_mask=None,
+               retriever_input=None,
+               retriever_output=None,
+               retriever_attn_mask=None,
+               inference_params=None,
+               rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Checks.
...
...
@@ -1258,11 +1636,13 @@ class ParallelTransformer(MegatronModule):
                    keep_graph=True,
                )

        # RNG context.
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        # Forward layers.
        with rng_context:
            # The fp8_autocast context manager is a no-op when enabled=True
            # The if...else serves to short circuit name resolution for fp8_autocast
...
...
@@ -1290,12 +1670,18 @@ class ParallelTransformer(MegatronModule):
                'encoder_output': encoder_output,
                'enc_dec_attn_mask': enc_dec_attn_mask,
                'inference_params': inference_params,
-               'rotary_pos_emb': rotary_pos_emb,
            }

            if self.transformer_impl == 'transformer_engine':
                forward_kwargs['is_first_microbatch'] = is_first_microbatch
                forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+               if self.transformer_engine_v_0_10:
+                   forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+           else:
+               forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+               forward_kwargs['retriever_input'] = retriever_input
+               forward_kwargs['retriever_output'] = retriever_output
+               forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

            for index in range(self.num_layers):
                layer = self._get_layer(index)
...
...
@@ -1305,6 +1691,14 @@ class ParallelTransformer(MegatronModule):
                    attention_mask,
                    **forward_kwargs)

+               # First Retro decoder layer returns both hidden_states
+               # and retriever_output. Make retriever_output available
+               # to subsequent Retro layers.
+               if isinstance(hidden_states, tuple):
+                   assert len(hidden_states) == 2
+                   hidden_states, retriever_output = hidden_states
+                   forward_kwargs["retriever_output"] = retriever_output
+
        # Skip counter update for eval and activation checkpointing
        if torch.is_grad_enabled() and self.training:
            self.microbatch_count += 1
...
...
megatron/model/vision/classification.py
...
...
@@ -13,7 +13,7 @@ from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
    """Vision Transformer Model."""

-   def __init__(self, num_classes, finetune=False,
+   def __init__(self, config, num_classes, finetune=False,
                 pre_process=True, post_process=True):
        super(VitClassificationModel, self).__init__()
        args = get_args()
...
...
@@ -24,6 +24,7 @@ class VitClassificationModel(MegatronModule):
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = VitBackbone(
+           config=config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            single_token_output=True
...
...
megatron/model/vision/dino.py
...
...
@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
    return schedule


-def get_student_backbone_and_num_features(pre_process=True, post_process=True):
+def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
-       student = VitBackbone(pre_process=pre_process,
+       student = VitBackbone(config,
+                             pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
...
...
@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):
    return student, num_features


-def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
+def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
-       teacher = VitBackbone(pre_process=pre_process,
+       teacher = VitBackbone(config,
+                             pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
...
...
@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):

class DINOPretrainModel(MegatronModule):
-   def __init__(self, pre_process=True, post_process=True):
+   def __init__(self, config, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.out_dim = 65536
...
...
@@ -234,7 +236,7 @@ class DINOPretrainModel(MegatronModule):
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
-           get_student_backbone_and_num_features(pre_process, post_process)
+           get_student_backbone_and_num_features(config, pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
...
...
@@ -249,7 +251,7 @@ class DINOPretrainModel(MegatronModule):
        )

        teacher_backbone, num_features = \
-           get_teacher_backbone_and_num_features(pre_process, post_process)
+           get_teacher_backbone_and_num_features(config, pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
...
...
megatron/model/vision/inpainting.py
...
...
@@ -18,14 +18,15 @@ from megatron.model.vision.utils import resize_

class VitInpaintingModel(MegatronModule):

-   def __init__(self, pre_process=True, post_process=True):
+   def __init__(self, config, pre_process=True, post_process=True):
        super(VitInpaintingModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
-       self.hidden_size = args.hidden_size
+       self.hidden_size = config.hidden_size
        self.backbone = VitBackbone(
+           config=config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            class_token=False,
...
...
megatron/model/vision/mit_backbone.py
-# ---------------------------------------------------------------
-# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
-#
-# This work is licensed under the NVIDIA Source Code License
-# found in the LICENSE file in the root directory of this
-# source tree.
-# ---------------------------------------------------------------
+# Copyright (c) 2023, NVIDIA Corporation. All rights reserved.

import math
import torch
import torch.nn as nn
...
...
megatron/model/vision/vit_backbone.py
...
...
@@ -130,24 +130,17 @@ class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
+                config,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 post_layer_norm=True,
                 drop_path_rate=0.0):
-       super(VitBackbone, self).__init__(share_word_embeddings=False)
+       super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
-       if args.init_method_xavier_uniform:
-           self.init_method = torch.nn.init.xavier_uniform_
-           self.scaled_init_method = torch.nn.init.xavier_uniform_
-       else:
-           self.init_method = init_method_normal(args.init_method_std)
-           self.scaled_init_method = scaled_init_method_normal(
-               args.init_method_std, args.num_layers
-           )

        self.pre_process = pre_process
        self.post_process = post_process
...
...
@@ -202,8 +195,7 @@ class VitBackbone(MegatronModule):

        # Transformer
        self.transformer = ParallelTransformer(
-           self.init_method,
-           self.scaled_init_method,
+           config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            post_layer_norm=self.post_layer_norm,
...
...