wuxk1 / Megatron-LM / Commits / 051f58f1

Commit 051f58f1, authored Mar 27, 2024 by liangjing

    v1

Parent: 0024a5c6
Changes: 203. Showing 20 changed files with 912 additions and 2247 deletions (+912 / -2247).
Changed files shown on this page:

  megatron/global_vars.py                    +2    -2
  megatron/initialize.py                     +121  -72
  megatron/model/bert_model.py               +17   -26
  megatron/model/classification.py           +3    -5
  megatron/model/distributed.py              +2    -2
  megatron/model/enums.py                    +4    -1
  megatron/model/fused_softmax.py            +3    -3
  megatron/model/gpt_model.py                +12   -13
  megatron/model/language_model.py           +78   -93
  megatron/model/module.py                   +13   -15
  megatron/model/multiple_choice.py          +3    -5
  megatron/model/retro_transformer.py        +0    -1731
  megatron/model/t5_model.py                 +7    -17
  megatron/model/transformer.py              +627  -233
  megatron/model/vision/classification.py    +2    -1
  megatron/model/vision/dino.py              +9    -7
  megatron/model/vision/inpainting.py        +3    -2
  megatron/model/vision/mit_backbone.py      +2    -7
  megatron/model/vision/vit_backbone.py      +3    -11
  megatron/optimizer/clip_grads.py           +1    -1
megatron/global_vars.py

@@ -80,7 +80,7 @@ def _set_signal_handler():

-def set_global_variables(args):
+def set_global_variables(args, build_tokenizer=True):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
     assert args is not None

@@ -89,7 +89,7 @@ def set_global_variables(args):
     set_args(args)

     _build_num_microbatches_calculator(args)
-    if args.vocab_file or args.tokenizer_model:
+    if build_tokenizer:
         _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
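Note on the new build_tokenizer flag: callers that previously relied on the args.vocab_file / args.tokenizer_model check can now opt out of tokenizer construction explicitly. A minimal caller-side sketch; the wrapper function here is hypothetical, only set_global_variables comes from the diff:

    # Hypothetical caller that sets up Megatron globals without a tokenizer,
    # e.g. a data-preprocessing utility that has no vocab file at hand.
    from megatron.global_vars import set_global_variables

    def setup_globals_without_tokenizer(args):
        # build_tokenizer=False skips _build_tokenizer(args); the global
        # tokenizer stays unset until something builds it explicitly.
        set_global_variables(args, build_tokenizer=False)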
megatron/initialize.py

@@ -15,36 +15,40 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron.core import mpu, tensor_parallel
-from megatron.arguments import (parse_args, validate_args)
+from megatron.arguments import parse_args, validate_args
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu


-def initialize_megatron(extra_args_provider=None, args_defaults={},
-                        ignore_unknown_args=False, allow_no_cuda=False):
+def initialize_megatron(
+    extra_args_provider=None,
+    args_defaults={},
+    ignore_unknown_args=False,
+    allow_no_cuda=False,
+):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds.
     `allow_no_cuda` should not be set unless using megatron for cpu only
     data processing. In general this arg should not be set unless you know
     what you are doing.
     Returns a function to finalize distributed env initialization
     (optionally, only when args.lazy_mpu_init == True)
     """
     if not allow_no_cuda:
         # Make sure cuda is available.
-        assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+        assert torch.cuda.is_available(), "Megatron requires CUDA."

     # Parse arguments
     args = parse_args(extra_args_provider, ignore_unknown_args)

-    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
-        assert args.load is not None, '--use-checkpoints-args requires --load argument'
+    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
+        assert args.load is not None, "--use-checkpoints-args requires --load argument"
         load_args_from_checkpoint(args)

     validate_args(args, args_defaults)

     # set global args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(args)

@@ -54,16 +58,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     args = get_args()

     # Pytorch distributed.
     _initialize_distributed()

     # Random seeds for reproducibility.
     if args.rank == 0:
-        print('> setting random seeds to {} ...'.format(args.seed))
+        print("> setting random seeds to {} ...".format(args.seed))
     _set_random_seed(args.seed, args.data_parallel_random_init)

     args = get_args()
     if args.lazy_mpu_init:
         # TODO is this still a necessary option?
         args.use_cpu_initialization = True
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)

@@ -95,11 +99,15 @@ def _compile_dependencies():
     # TODO: move this to ninja
     if torch.distributed.get_rank() == 0:
         start_time = time.time()
-        print('> compiling dataset index builder ...')
+        print("> compiling dataset index builder ...")
         from megatron.data.dataset_utils import compile_helper
         compile_helper()
-        print('>>> done with dataset index builder. Compilation time: {:.3f} '
-              'seconds'.format(time.time() - start_time), flush=True)
+        print(
+            ">>> done with dataset index builder. Compilation time: {:.3f} "
+            "seconds".format(time.time() - start_time),
+            flush=True,
+        )

     # ==================
     # Load fused kernels
     # ==================

@@ -107,41 +115,51 @@ def _compile_dependencies():
     # Custom kernel constraints check.
     seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
+    attn_batch_size = (
+        args.num_attention_heads / args.tensor_model_parallel_size
+    ) * args.micro_batch_size
     # Constraints on sequence length and attn_batch_size to enable warp based
     # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    custom_kernel_constraint = (
+        seq_len > 16
+        and seq_len <= 16384
+        and seq_len % 4 == 0
+        and attn_batch_size % 4 == 0
+    )
     # Print a warning.
-    if not ((args.fp16 or args.bf16) and
-            custom_kernel_constraint and
-            args.masked_softmax_fusion):
+    if not (
+        (args.fp16 or args.bf16)
+        and custom_kernel_constraint
+        and args.masked_softmax_fusion
+    ):
         if args.rank == 0:
-            print('WARNING: constraints for invoking optimized'
-                  ' fused softmax kernel are not met. We default'
-                  ' back to unfused kernel invocations.', flush=True)
+            print(
+                "WARNING: constraints for invoking optimized"
+                " fused softmax kernel are not met. We default"
+                " back to unfused kernel invocations.",
+                flush=True,
+            )

     # Always build on rank zero first.
     if torch.distributed.get_rank() == 0:
         start_time = time.time()
-        print('> compiling and loading fused kernels ...', flush=True)
-        fused_kernels.load(args)
+        print("> compiling and loading fused kernels ...", flush=True)
+        # fused_kernels.load(args)
         torch.distributed.barrier()
     else:
         torch.distributed.barrier()
-        fused_kernels.load(args)
+        # fused_kernels.load(args)
     # Simple barrier to make sure all ranks have passed the
     # compilation phase successfully before moving on to the
     # rest of the program. We think this might ensure that
     # the lock is released.
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
-        print('>>> done with compiling and loading fused kernels. '
-              'Compilation time: {:.3f} seconds'.format(
-              time.time() - start_time), flush=True)
+        print(
+            ">>> done with compiling and loading fused kernels. "
+            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
+            flush=True,
+        )


 def _initialize_distributed():

@@ -152,45 +170,58 @@ def _initialize_distributed():
     if torch.distributed.is_initialized():

         if args.rank == 0:
-            print('torch distributed is already initialized, '
-                  'skipping initialization ...', flush=True)
-        args.rank = torch.distributed.get_rank()
-        args.world_size = torch.distributed.get_world_size()
+            print(
+                "torch distributed is already initialized, "
+                "skipping initialization ...",
+                flush=True,
+            )
+        #args.rank = torch.distributed.get_rank()
+        #args.world_size = torch.distributed.get_world_size()

     else:

         if args.rank == 0:
-            print('> initializing torch distributed ...', flush=True)
+            print("> initializing torch distributed ...", flush=True)
         # Manually set the device ids.
         if device_count > 0:
             device = args.rank % device_count
             if args.local_rank is not None:
-                assert args.local_rank == device, \
-                    'expected local-rank to be the same as rank % device-count.'
+                assert (
+                    args.local_rank == device
+                ), "expected local-rank to be the same as rank % device-count."
             else:
                 args.local_rank = device
             torch.cuda.set_device(device)
         # Call the init process
         torch.distributed.init_process_group(
-            backend=args.distributed_backend,
-            world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(minutes=args.distributed_timeout_minutes))
+            backend=args.distributed_backend,
+            world_size=args.world_size,
+            rank=args.rank,
+            init_method=args.dist_url,
+            timeout=timedelta(minutes=args.distributed_timeout_minutes),
+        )

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
     if device_count > 0:
         if mpu.model_parallel_is_initialized():
-            print('model parallel is already initialized')
+            print("model parallel is already initialized")
         else:
-            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size,
-                                          args.pipeline_model_parallel_split_rank)
+            mpu.initialize_model_parallel(
+                args.tensor_model_parallel_size,
+                args.pipeline_model_parallel_size,
+                args.virtual_pipeline_model_parallel_size,
+                args.pipeline_model_parallel_split_rank,
+                args.fp8 is not None,
+            )
             if args.rank == 0:
-                print(f'> initialized tensor model parallel with size '
-                      f'{mpu.get_tensor_model_parallel_world_size()}')
-                print(f'> initialized pipeline model parallel with size '
-                      f'{mpu.get_pipeline_model_parallel_world_size()}')
+                print(
+                    f"> initialized tensor model parallel with size "
+                    f"{mpu.get_tensor_model_parallel_world_size()}"
+                )
+                print(
+                    f"> initialized pipeline model parallel with size "
+                    f"{mpu.get_pipeline_model_parallel_world_size()}"
+                )


 def _init_autoresume():

@@ -216,7 +247,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         if torch.cuda.device_count() > 0:
             tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
-        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
+        raise ValueError("Seed ({}) should be a positive integer.".format(seed))


 def write_args_to_tensorboard():

@@ -225,15 +256,14 @@ def write_args_to_tensorboard():
     writer = get_tensorboard_writer()
     if writer:
         for arg in vars(args):
-            writer.add_text(arg, str(getattr(args, arg)),
-                            global_step=args.iteration)
+            writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)


 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
-    TORCH_MAJOR = int(torch.__version__.split('.')[0])
-    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
     if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
         # nvfuser
         torch._C._jit_set_profiling_executor(True)

@@ -241,7 +271,7 @@ def set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_cpu(False)
         torch._C._jit_override_can_fuse_on_gpu(False)
         torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._jit_set_nvfuser_enabled(False)
         torch._C._debug_set_autodiff_subgraph_inlining(False)
     else:
         # legacy pytorch fuser

@@ -254,7 +284,7 @@ def set_jit_fusion_options():
 def _warmup_jit_function():
-    """ Compilie JIT functions before the main training steps """
+    """Compilie JIT functions before the main training steps"""
     args = get_args()
     if args.bf16:
         dtype = torch.bfloat16

@@ -264,11 +294,20 @@ def _warmup_jit_function():
         dtype = torch.float32

     # Warmup fused bias+gelu
-    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
-                      dtype=dtype, device='cuda')
-    input = torch.rand((args.seq_length, args.micro_batch_size,
-                        args.ffn_hidden_size // args.tensor_model_parallel_size),
-                       dtype=dtype, device='cuda')
+    bias = torch.rand(
+        args.ffn_hidden_size // args.tensor_model_parallel_size,
+        dtype=dtype,
+        device="cuda",
+    )
+    input = torch.rand(
+        (
+            args.seq_length,
+            args.micro_batch_size,
+            args.ffn_hidden_size // args.tensor_model_parallel_size,
+        ),
+        dtype=dtype,
+        device="cuda",
+    )
     # Warmup JIT fusions with the input grad_enable state of both forward
     # prop and recomputation
     for bias_grad, input_grad in zip([True, True], [False, True]):

@@ -282,15 +321,25 @@ def _warmup_jit_function():
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length
-    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
-                       dtype=dtype, device='cuda')
-    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
-                          dtype=dtype, device='cuda')
-    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
+    input = torch.rand(
+        (seq_length, args.micro_batch_size, args.hidden_size),
+        dtype=dtype,
+        device="cuda",
+    )
+    residual = torch.rand(
+        (seq_length, args.micro_batch_size, args.hidden_size),
+        dtype=dtype,
+        device="cuda",
+    )
+    bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
     dropout_rate = 0.1
     # Warmup JIT fusions with the input grad_enable state of both forward
     # prop and recomputation
-    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+    for input_grad, bias_grad, residual_grad in zip(
+        [False, True], [True, True], [True, True]
+    ):
         input.requires_grad = input_grad
         bias.requires_grad = bias_grad
         residual.requires_grad = residual_grad
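For reference, a minimal sketch of how initialize_megatron is typically driven after this change. It assumes the fork defines an args.dist_url argument elsewhere (the new init_method=args.dist_url in init_process_group requires it), which is not shown in this diff; the args_defaults values are illustrative only:

    # Minimal driver sketch under the assumptions stated above.
    from megatron import get_args
    from megatron.initialize import initialize_megatron

    def main():
        initialize_megatron(
            extra_args_provider=None,
            args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
            ignore_unknown_args=False,
            allow_no_cuda=False,
        )
        args = get_args()
        # With the changes above, torch.distributed is initialized with
        # init_method=args.dist_url and fused-kernel loading is commented out.
        print(f'> rank {args.rank} / world size {args.world_size} ready')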
megatron/model/bert_model.py

@@ -47,31 +47,27 @@ class BertLMHead(MegatronModule):
     """Masked LM head for Bert

     Arguments:
+        config: TransformerConfig object
         mpu_vocab_size: model parallel size of vocabulary.
         hidden_size: hidden size
-        init_method: init method for weight initialization
-        layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: whether output logits being distributed or not.
     """

-    def __init__(self, mpu_vocab_size, hidden_size, init_method,
-                 layernorm_epsilon, parallel_output):
-        super(BertLMHead, self).__init__()
+    def __init__(self, mpu_vocab_size, hidden_size, config,
+                 parallel_output):
+        super().__init__(config=config)

         args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

-        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
-        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
+        self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method)
+        setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
+        setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)

         self.layernorm = LayerNorm(hidden_size,
-                                   eps=layernorm_epsilon,
-                                   sequence_parallel=args.sequence_parallel)
+                                   eps=config.layernorm_epsilon,
+                                   sequence_parallel=config.sequence_parallel)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu

@@ -124,12 +120,13 @@ class BertModel(MegatronModule):
     """Bert Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=2,
                  add_binary_head=True,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
-        super(BertModel, self).__init__()
+        super().__init__(config=config)
         args = get_args()

         # TODO this option is not yet implemented in BERT

@@ -145,29 +142,23 @@ class BertModel(MegatronModule):
         if self.return_embeddings:
             assert self.post_process and self.add_binary_head

-        init_method = init_method_normal(args.init_method_std)
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process)

-        self.initialize_word_embeddings(init_method_normal)
+        self.initialize_word_embeddings()
         if self.post_process:
             self.lm_head = BertLMHead(
-                self.word_embeddings_weight().size(0),
-                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
+                self.shared_embedding_or_output_weight().size(0),
+                config.hidden_size, config, parallel_output)
             self._lm_head_key = 'lm_head'
             self.binary_head = None
             if self.add_binary_head:
-                self.binary_head = get_linear_layer(args.hidden_size, 2,
-                                                    init_method)
+                self.binary_head = get_linear_layer(config.hidden_size, 2,
+                                                    config.init_method)
                 self._binary_head_key = 'binary_head'

     def set_input_tensor(self, input_tensor):

@@ -215,7 +206,7 @@ class BertModel(MegatronModule):
             return post_language_model_processing(lm_output, pooled_output,
                                                   self.lm_head, self.binary_head,
                                                   lm_labels,
-                                                  self.word_embeddings_weight(),
+                                                  self.shared_embedding_or_output_weight(),
                                                   self.fp16_lm_cross_entropy)
         else:
             return lm_output
megatron/model/classification.py

@@ -17,25 +17,23 @@ from .module import MegatronModule
 class Classification(MegatronModule):

     def __init__(self,
+                 config,
                  num_classes,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):
-        super(Classification, self).__init__(share_word_embeddings=False)
+        super().__init__(config=config, share_embeddings_and_output_weights=False)
         args = get_args()

         self.num_classes = num_classes
         self.pre_process = pre_process
         self.post_process = post_process
-        init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)
megatron/model/distributed.py

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 from abc import ABC
 from abc import abstractmethod

@@ -73,7 +73,7 @@ class DistributedDataParallelBase(MegatronModule, ABC):
 class DistributedDataParallel(DistributedDataParallelBase):
-    """DDP with contiguous buffers options to storre and accumulate gradients.
+    """DDP with contiguous buffers options to store and accumulate gradients.
     This class:
         - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
megatron/model/enums.py

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 import enum

 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
+    retro_encoder = 3
+    retro_decoder = 4
+    retro_decoder_with_retriever = 5

 class AttnType(enum.Enum):
     self_attn = 1
megatron/model/fused_softmax.py

@@ -155,12 +155,12 @@ class FusedScaleMaskSoftmax(nn.Module):
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
-            and 16 < sk <= 4096  # sk must be 16 ~ 2048
+            and 16 < sk <= 16384  # sk must be 16 ~ 16384
             and sq % 4 == 0  # sq must be divisor of 4
             and sk % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
-            if 0 <= sk <= 4096:
+            if 0 <= sk <= 16384:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                 if self.attn_mask_type == AttnMaskType.causal:
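The same relaxed bound appears in _compile_dependencies above (4096 raised to 16384). A small standalone predicate that mirrors the fusion conditions in this hunk, written out for illustration only; all names are local to this sketch:

    # Mirrors the condition checked above after this change.
    def fused_softmax_applicable(sq, sk, b, np_heads,
                                 input_in_float16, fusion_enabled):
        attn_batches = b * np_heads
        return (
            fusion_enabled              # user wants to fuse
            and input_in_float16        # input must be fp16/bf16
            and 16 < sk <= 16384        # sk limit raised from 4096 to 16384
            and sq % 4 == 0
            and sk % 4 == 0
            and attn_batches % 4 == 0
        )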
megatron/model/gpt_model.py

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """GPT-2 model."""

@@ -11,8 +11,6 @@ from .module import MegatronModule
 from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
-from .utils import init_method_normal
-from .utils import scaled_init_method_normal


 def post_language_model_processing(lm_output, labels, logit_weights,

@@ -46,12 +44,13 @@ class GPTModel(MegatronModule):
     """GPT-2 Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=0,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
         args = get_args()
-        super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
+        super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

         self.parallel_output = parallel_output
         self.pre_process = pre_process

@@ -60,39 +59,39 @@ class GPTModel(MegatronModule):
         self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.causal,
-            init_method=init_method_normal(args.init_method_std),
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)

         if not args.untie_embeddings_and_output_weights:
-            self.initialize_word_embeddings(init_method_normal)
+            self.initialize_word_embeddings()

     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
                 labels=None, tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            ret_input_ids=ret_input_ids,
-            ret_position_ids=ret_position_ids,
-            ret_attn_mask=ret_attn_mask,
+            retriever_input_ids=retriever_input_ids,
+            retriever_position_ids=retriever_position_ids,
+            retriever_attn_mask=retriever_attn_mask,
             inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
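GPTModel now takes a config object as its first argument instead of building init methods from args. A sketch of the matching model-provider pattern; the core_transformer_config_from_args helper is assumed to exist in this tree (it is the usual companion of the config-based constructors but is not among the files shown on this page):

    # Sketch of a config-first model provider; helper availability is assumed.
    from megatron import get_args
    from megatron.arguments import core_transformer_config_from_args
    from megatron.model import GPTModel

    def model_provider(pre_process=True, post_process=True):
        args = get_args()
        config = core_transformer_config_from_args(args)
        return GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )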
megatron/model/language_model.py

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """Transformer based language model."""

@@ -7,11 +7,11 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron.core import mpu, tensor_parallel
+from megatron.core.enums import ModelType
+from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding

-from .enums import LayerType, AttnMaskType
+from .enums import AttnMaskType, LayerType
 from .module import MegatronModule
-from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
-from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
 from .transformer import ParallelTransformer
 from .utils import get_linear_layer
 from .utils import init_method_normal, scaled_init_method_normal

@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
         bias=bias,
         gradient_accumulation_fusion=args.gradient_accumulation_fusion,
         async_grad_allreduce=async_grad_allreduce,
-        sequence_parallel_enabled=args.sequence_parallel)
+        sequence_parallel=args.sequence_parallel)

     # Gather if needed.
     if parallel_output:

@@ -48,26 +48,24 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


-def get_language_model(num_tokentypes, add_pooler,
-                       encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_encoder=True,
+def get_language_model(config, num_tokentypes, add_pooler,
+                       encoder_attn_mask_type,
+                       add_encoder=True,
                        add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()
+    if config.init_method is None:
+        config.init_method = init_method_normal(config.init_method_std)

-    if init_method is None:
-        init_method = init_method_normal(args.init_method_std)
-
-    if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
+    if config.output_layer_init_method is None:
+        config.output_layer_init_method = scaled_init_method_normal(config.init_method_std,
+                                                                    config.num_layers)

     # Language model.
     language_model = TransformerLanguageModel(
-        init_method,
-        scaled_init_method,
+        config,
         encoder_attn_mask_type,
         num_tokentypes=num_tokentypes,
         add_encoder=add_encoder,

@@ -131,6 +129,10 @@ class Embedding(MegatronModule):
         init_method: weight initialization method
         num_tokentypes: size of the token-type embeddings. 0 value
                         will ignore this embedding
+        embedding_weights_in_fp32: casts word embedding weights to
+                                   fp32 before sampling. Required to
+                                   maintain reproducibility when
+                                   training in bf16.
     """

     def __init__(self,

@@ -138,28 +140,26 @@ class Embedding(MegatronModule):
                  vocab_size,
                  max_sequence_length,
                  embedding_dropout_prob,
-                 init_method,
-                 num_tokentypes=0):
+                 config,
+                 num_tokentypes=0,
+                 embedding_weights_in_fp32=False):
         super(Embedding, self).__init__()

         self.hidden_size = hidden_size
-        self.init_method = init_method
+        self.init_method = config.init_method
         self.num_tokentypes = num_tokentypes

         args = get_args()

         # Word embeddings (parallel).
+        self.embedding_weights_in_fp32 = embedding_weights_in_fp32
+        self.params_dtype = args.params_dtype
         self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
             vocab_size, self.hidden_size,
-            init_method=self.init_method,
-            params_dtype=args.params_dtype,
-            use_cpu_initialization=args.use_cpu_initialization,
-            perform_initialization=args.perform_initialization)
+            config=config, init_method=config.init_method)
         self._word_embeddings_key = 'word_embeddings'

         # Position embedding (serial).
-        self.add_position_embedding = args.add_position_embedding
+        self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
         if self.add_position_embedding:
             self.position_embeddings = torch.nn.Embedding(
                 max_sequence_length, self.hidden_size)

@@ -182,7 +182,7 @@ class Embedding(MegatronModule):
         else:
             self.tokentype_embeddings = None

         self.fp32_residual_connection = args.fp32_residual_connection
         self.sequence_parallel = args.sequence_parallel
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

@@ -217,7 +217,12 @@ class Embedding(MegatronModule):
     def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Embeddings.
+        if self.embedding_weights_in_fp32:
+            self.word_embeddings = self.word_embeddings.to(torch.float32)
         words_embeddings = self.word_embeddings(input_ids)
+        if self.embedding_weights_in_fp32:
+            words_embeddings = words_embeddings.to(self.params_dtype)
+            self.word_embeddings = self.word_embeddings.to(self.params_dtype)
         if self.add_position_embedding:
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = words_embeddings + position_embeddings

@@ -326,8 +331,7 @@ class TransformerLanguageModel(MegatronModule):
     """

     def __init__(self,
-                 init_method,
-                 output_layer_init_method,
+                 config,
                  encoder_attn_mask_type,
                  num_tokentypes=0,
                  add_encoder=True,

@@ -337,21 +341,22 @@ class TransformerLanguageModel(MegatronModule):
                  pre_process=True,
                  post_process=True):
         args = get_args()
-        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
         if args.untie_embeddings_and_output_weights: assert not add_decoder
-        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
+        super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

         self.pre_process = pre_process
         self.post_process = post_process
-        self.hidden_size = args.hidden_size
+        self.hidden_size = config.hidden_size
         self.num_tokentypes = num_tokentypes
-        self.init_method = init_method
+        self.init_method = config.init_method
         self.add_encoder = add_encoder
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.add_retriever = args.retro_add_retriever
         self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         # Embeddings.

@@ -360,14 +365,15 @@ class TransformerLanguageModel(MegatronModule):
                 args.padded_vocab_size,
                 args.max_position_embeddings,
                 args.hidden_dropout,
-                self.init_method,
-                self.num_tokentypes)
+                config,
+                self.num_tokentypes,
+                args.embedding_weights_in_fp32)
             self._embedding_key = 'embedding'

         # Rotary positional embeddings
         self.use_rotary_position_embeddings = \
-            args.use_rotary_position_embeddings
-        if args.use_rotary_position_embeddings:
+            args.position_embedding_type == 'rope'
+        if self.use_rotary_position_embeddings:
             self.seq_length = args.seq_length
             rotary_dim = args.hidden_size // args.num_attention_heads \
                 if args.kv_channels is None else args.kv_channels

@@ -378,41 +384,22 @@ class TransformerLanguageModel(MegatronModule):
             # partial rotary embeddings, which is better than full rotary
             # Wang and Komatsuzaki et al
             # https://github.com/kingoflolz/mesh-transformer-jax/
-            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
-
-        # Retriever (bi-directional transformer with cross attention)
-        if args.retro_add_retriever:
-            self.retriever = ParallelRetroEncoder(
-                self.init_method,
-                output_layer_init_method,
-                self_attn_mask_type=AttnMaskType.padding,
-                pre_process=self.pre_process,
-                post_process=False,
-            )
-            self._retriever_key = 'retriever'
-        else:
-            self.retriever = None
+            self.rotary_pos_emb = RotaryEmbedding(
+                rotary_dim,
+                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
+            )

         # Encoder (usually set to True, False if part of an encoder-decoder
         # architecture and in encoder-only stage).
         if self.add_encoder:
-            if args.retro_add_retriever:
-                self.encoder = ParallelRetroTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                    retriever=self.retriever,
-                )
-            else:
-                self.encoder = ParallelTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                )
+            self.encoder = ParallelTransformer(
+                config,
+                model_type=args.model_type if not args.retro_add_retriever \
+                    else ModelType.retro_decoder,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process,
+            )
             self._encoder_key = 'encoder'
         else:
             self.encoder = None

@@ -421,8 +408,8 @@ class TransformerLanguageModel(MegatronModule):
         # architecture and in decoder-only stage).
         if self.add_decoder:
             self.decoder = ParallelTransformer(
-                self.init_method,
-                output_layer_init_method,
+                config,
+                model_type=args.model_type,
                 layer_type=LayerType.decoder,
                 self_attn_mask_type=self.decoder_attn_mask_type,
                 pre_process=self.pre_process,

@@ -441,8 +428,9 @@ class TransformerLanguageModel(MegatronModule):
             self.output_layer = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 args.padded_vocab_size,
-                bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
-                init_method=self.init_method)
+                config=config,
+                init_method=self.init_method,
+                bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
             self._output_layer_key = 'output_layer'

     def set_input_tensor(self, input_tensor):

@@ -475,19 +463,14 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

-        # Retriever embedding.
-        if self.retriever and self.pre_process:
-            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
-                                             tokentype_ids=tokentype_ids)
-        else:
-            retriever_input = None
-
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,

@@ -495,31 +478,33 @@ class TransformerLanguageModel(MegatronModule):
         else:
             encoder_input = None

+        # Retriever embedding.
+        if self.add_retriever and self.pre_process:
+            retriever_input = self.embedding(retriever_input_ids,
+                                             retriever_position_ids,
+                                             tokentype_ids=tokentype_ids)
+        else:
+            retriever_input = None
+
         # Rotary positional embeddings
         rotary_pos_emb = None
         if self.use_rotary_position_embeddings:
             if inference_params is not None:
                 rotary_pos_emb = \
-                    self.rotary_pos_emb(inference_params.max_sequence_len)
+                    self.rotary_pos_emb(inference_params.max_sequence_length)
             else:
                 rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
-                if self.retriever:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        retriever_output=retriever_input,
-                        retriever_attn_mask=ret_attn_mask,
-                        inference_params=inference_params)
-                else:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        inference_params=inference_params,
-                        rotary_pos_emb=rotary_pos_emb)
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    retriever_input=retriever_input,
+                    retriever_attn_mask=retriever_attn_mask,
+                    inference_params=inference_params,
+                    rotary_pos_emb=rotary_pos_emb)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
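The two separate flags (args.add_position_embedding and args.use_rotary_position_embeddings) are now both derived from a single args.position_embedding_type string, as the hunks above show. Equivalent caller-side logic, written out for clarity; the helper name is local to this sketch:

    # Derives the embedding-selection booleans the same way the new
    # Embedding / TransformerLanguageModel code does.
    def position_embedding_flags(position_embedding_type):
        add_learned_absolute = position_embedding_type == 'learned_absolute'
        use_rope = position_embedding_type == 'rope'
        return add_learned_absolute, use_rope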
megatron/model/module.py
View file @
051f58f1
...
@@ -25,9 +25,10 @@ class MegatronModule(torch.nn.Module):
...
@@ -25,9 +25,10 @@ class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
"""Megatron specific extensions of torch Module with support
for pipelining."""
for pipelining."""
def
__init__
(
self
,
share_word_embedding
s
=
True
):
def
__init__
(
self
,
config
=
None
,
share_embeddings_and_output_weight
s
=
True
):
super
(
        super(MegatronModule, self).__init__()
-       self.share_word_embeddings = share_word_embeddings
+       self.config = config
+       self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
...
@@ -36,21 +37,21 @@ class MegatronModule(torch.nn.Module):
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

-   def word_embeddings_weight(self):
+   def shared_embedding_or_output_weight(self):
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        else:
-           if not self.share_word_embeddings:
-               raise Exception('word_embeddings_weight() called for last '
-                               'stage, but share_word_embeddings is false')
+           if not self.share_embeddings_and_output_weights:
+               raise Exception('shared_embedding_or_output_weight() called for last '
+                               'stage, but share_embeddings_and_output_weights is false')
            return self.word_embeddings.weight

-   def initialize_word_embeddings(self, init_method_normal):
+   def initialize_word_embeddings(self):
        args = get_args()
-       if not self.share_word_embeddings:
+       if not self.share_embeddings_and_output_weights:
            raise Exception('initialize_word_embeddings() was called but '
-                           'share_word_embeddings is false')
+                           'share_embeddings_and_output_weights is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
...
@@ -76,11 +77,8 @@ class MegatronModule(torch.nn.Module):
        # set word_embeddings weights to 0 here, then copy first
        # stage's weights using all_reduce below.
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-           args.padded_vocab_size, args.hidden_size,
-           init_method=init_method_normal(args.init_method_std),
-           params_dtype=args.params_dtype,
-           use_cpu_initialization=args.use_cpu_initialization,
-           perform_initialization=args.perform_initialization)
+           args.padded_vocab_size, self.config.hidden_size,
+           config=self.config, init_method=self.config.init_method)
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
...
@@ -103,7 +101,7 @@ class MegatronModule(torch.nn.Module):
        # Ensure that first and last stages have the same initial parameter
        # values.
        if mpu.is_rank_in_embedding_group():
-           torch.distributed.all_reduce(self.word_embeddings_weight().data,
+           torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
                                         group=mpu.get_embedding_group())

        # Ensure that encoder(first stage) and decoder(split stage) position
...
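For orientation, a minimal usage sketch (not part of this diff) of the renamed accessor: on the last pipeline stage the output projection reuses the shared word embedding. The helper name post_process_logits and its arguments are hypothetical; parallel_lm_logits is the existing helper in megatron.model.language_model.

def post_process_logits(model, lm_output, parallel_output=True):
    # Before this commit this read model.word_embeddings_weight().
    logit_weights = model.shared_embedding_or_output_weight()
    return parallel_lm_logits(lm_output, logit_weights, parallel_output)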
megatron/model/multiple_choice.py  View file @ 051f58f1
...
@@ -17,23 +17,21 @@ from .module import MegatronModule

class MultipleChoice(MegatronModule):

    def __init__(self,
+                config,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
-       super(MultipleChoice, self).__init__(share_word_embeddings=False)
+       super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
        args = get_args()
-       init_method = init_method_normal(args.init_method_std)

        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
+           config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
-           init_method=init_method,
-           scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                        args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)
...
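As a hedged construction sketch (not in the commit), a model_provider would now build MultipleChoice with the new config argument; core_transformer_config_from_args is assumed to be the companion helper in megatron.arguments for this refactor.

from megatron import get_args
from megatron.arguments import core_transformer_config_from_args   # assumption
from megatron.model.multiple_choice import MultipleChoice

def model_provider(pre_process=True, post_process=True):
    args = get_args()
    config = core_transformer_config_from_args(args)
    return MultipleChoice(config=config, num_tokentypes=2,
                          pre_process=pre_process, post_process=post_process)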
megatron/model/retro_transformer.py  deleted 100644 → 0  View file @ 0024a5c6

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Retro Transformer.

** Special note about this file **

Many classes and methods in this file directly parallel those in transformer.py
in name and utility. However, due to 1) subtle changes in the code over time
(i.e., transposes and contexts), and 2) other code that is soon to be merged,
this file will *temporarily* remain as is, until a larger integration is
complete.
"""

import math

import numpy as np
import torch
import torch.nn.functional as F

from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, init_method_normal

from .module import MegatronModule
from .transformer import _get_num_layers, ParallelMLP, NoopTransformerLayer

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    *Note: differs from transformer.py/DropPath in hidden_state transpose.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output
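As a standalone toy illustration (not from the file): stochastic depth zeroes whole samples and rescales the survivors by 1/keep_prob, so the expected activation is unchanged.

import torch

drop = DropPath(drop_prob=0.2)
drop.train()
x = torch.ones(8, 4, 16)   # sample dimension first, per the transpose note above
y = drop(x)
# each of the 8 samples is now either all zeros or all 1 / 0.8 = 1.25,
# so E[y] equals x in expectation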
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        b = hidden_states.size(0)
        s = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [b, s, h] to [b*s, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
        max_ind = max_ind.view(-1)  # [b*s]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.view(b, s, h)
        output_bias_total = output_bias_total.view(b, s, h)

        return output_total, output_bias_total
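The same top-1 routing idea, written out as a self-contained sketch with plain torch modules (toy sizes; ParallelMLP replaced by nn.Linear purely for illustration).

import torch

b, s, h, num_experts = 2, 3, 4, 2
router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList(torch.nn.Linear(h, h) for _ in range(num_experts))

x = torch.randn(b, s, h)
probs = torch.softmax(router(x), dim=2)
max_prob, max_ind = probs.max(dim=2)                 # [b, s]
flat_x = x.view(-1, h)                               # [b*s, h]
flat_ind = max_ind.view(-1)                          # [b*s]
out = torch.empty_like(flat_x)
for e, expert in enumerate(experts):
    idx = (flat_ind == e).nonzero(as_tuple=True)[0]  # tokens routed to expert e
    out[idx] = expert(flat_x[idx])
out = out.view(b, s, h) * max_prob.unsqueeze(2)      # scale by router probability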
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core_utils.divide(projection_size, world_size)
        self.hidden_size_per_attention_head = core_utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core_utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size, 3 * projection_size,
                gather_output=False, init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                args.hidden_size, projection_size,
                gather_output=False, init_method=init_method)
            self.key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size, 2 * projection_size,
                gather_output=False, init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size, args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len, batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # Pre-allocate key/value memory for inference, project to query, key
        # and value ([sq, b, h] --> [sq, b, np, hn]), adjust key/value for the
        # inference cache, and compute the raw attention scores [b, np, sq, sk]
        # with a scaled torch.baddbmm over [b * np, sq, hn] x [b * np, hn, sk].
        ...

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with tensor_parallel.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # Context layer: bmm of probs and values, [b, np, sq, hn] permuted and
        # reshaped back to [sq, b, hp].
        ...

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias
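A shape walk-through of the baddbmm used for the raw attention scores, with assumed toy values and standalone tensors; it mirrors the commented shapes above, not Megatron internals.

import torch

sq, sk, b, np_, hn = 5, 5, 2, 3, 8          # seq (q/k), batch, heads per partition, head dim
q = torch.randn(sq, b, np_, hn)
k = torch.randn(sk, b, np_, hn)

q2 = q.reshape(sq, b * np_, hn)             # [sq, b*np, hn]
k2 = k.reshape(sk, b * np_, hn)             # [sk, b*np, hn]
scores = torch.baddbmm(
    torch.empty(b * np_, sq, sk),           # ignored since beta=0.0
    q2.transpose(0, 1),                     # [b*np, sq, hn]
    k2.transpose(0, 1).transpose(1, 2),     # [b*np, hn, sk]
    beta=0.0, alpha=1.0 / hn ** 0.5)
scores = scores.view(b, np_, sq, sk)        # [b, np, sq, sk]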
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
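Every layer forward below repeats the same selection between these variants; written out once as a sketch (select_bias_dropout_add is a name introduced here for illustration, not a function in the file):

def select_bias_dropout_add(bias_dropout_fusion: bool, training: bool):
    # Fused JIT kernels when bias_dropout_fusion is enabled, otherwise the
    # closure returned by get_bias_dropout_add.
    if bias_dropout_fusion:
        return bias_dropout_add_fused_train if training \
            else bias_dropout_add_fused_inference
    return get_bias_dropout_add(training)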
class ParallelRetroTransformerEncoderLayer(MegatronModule):
    """A single transformer layer for Retro Decoder with a retriever encoder inside and cross attention.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0., retriever=None):
        args = get_args()
        super(ParallelRetroTransformerEncoderLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Retro Encoder
        self.retriever = retriever

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Input layernorm, self attention, residual connection with (fused)
        # bias-dropout-add or DropPath, then the post-attention layernorm.
        ...
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        """
        notations:
            l: number of chunks
            m: number of token per chunk
            bs: batch size
            d: hidden size
            k: number of neighbors
            r: number of tokens per neighbors (neighbors + continuation)
        """
        args = get_args()
        retro_args = get_retro_args()
        chunk_length = retro_args.retro_gpt_chunk_length
        retrieved_length = retro_args.retro_gpt_retrieved_length
        num_neighbors = args.retro_num_neighbors

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / chunk_length))
        first_ns = ns % chunk_length
        if first_ns > 0:
            first_chunk, rest_chunk = \
                layernorm_output[:first_ns], layernorm_output[first_ns:]
            first_chunk = torch.nn.functional.pad(
                first_chunk,
                (0, 0, 0, 0, 0, chunk_length - first_ns),
                'constant', 0)
            chunked_output = \
                torch.cat((first_chunk, rest_chunk), dim=0)  # [l * m, bs, d]
        else:
            chunked_output = layernorm_output  # [l * m, bs, d]
        chunked_output = chunked_output \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3) \
            .reshape(chunk_length, bs * l, d) \
            .contiguous()

        # Get Encoder Output
        retriever_output = self.retriever(
            retriever_output, retriever_attn_mask,
            retriever_output=chunked_output,
            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)  # [r, k * bs * l , d]
        retriever_output = retriever_output.reshape(
            retrieved_length * num_neighbors, bs * l, d)  # [r * k, bs * l, d]

        # Chunked Cross attention with Retriever Encoder
        pad = (ns - 1) % chunk_length
        attending_chunks = layernorm_output[pad:]  # [ns - m + 1, bs, d]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, chunk_length - 1),
            'constant', 0)  # [ns, bs, d]
        padded_chunked_output = padded_chunks \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            chunk_length, bs * l, d).contiguous()  # [m, bs * l, d]

        # attention_output: [m, bs * l, d]
        # attention_bias: [d]
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,  # Q: main model embedding
                                 None,
                                 encoder_output=retriever_output)  # KV: retriever output embedding

        # Residual connection, bias-dropout-add, un-chunking back to [ns, b, d],
        # post-inter-attention layernorm, MLP and the second residual connection
        # with bias-dropout-add or DropPath.
        ...

        return output, retriever_output
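Worked numbers (assumed toy sizes, not from the file) for the chunking arithmetic above: with ns decoder tokens and chunk length m, the sequence is split into l chunks, a partial leading chunk is padded, and the attending window is shifted by pad tokens so each chunk cross-attends only after it is complete.

import numpy as np

ns, chunk_length = 10, 4
l = int(np.ceil(ns / chunk_length))   # 3 chunks
first_ns = ns % chunk_length          # 2 tokens in the partial leading chunk
pad = (ns - 1) % chunk_length         # 1: offset used for attending_chunks
print(l, first_ns, pad)               # 3 2 1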
class ParallelRetroTransformerLayer(MegatronModule):
    """A single transformer layer for Retro Decoder with cross attention.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()
        super(ParallelRetroTransformerLayer, self).__init__()
        # Same sub-modules as ParallelRetroTransformerEncoderLayer above
        # (input / post-attention / post-inter-attention layernorms, self and
        # cross ParallelAttention, DropPath, SwitchMLP or ParallelMLP), but
        # without an embedded retriever.
        ...

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [b, s, h]

        # Input layernorm, self attention, bias-dropout-add and
        # post-attention layernorm, as in the layer above.
        ...
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        args = get_args()
        retro_args = get_retro_args()
        chunk_length = retro_args.retro_gpt_chunk_length

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / chunk_length))
        pad = (ns - 1) % chunk_length
        attending_chunks = layernorm_output[pad:]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output, None,
                                 encoder_output=retriever_output)

        # Residual connection, bias-dropout-add, un-chunking back to [ns, b, d],
        # post-inter-attention layernorm, MLP and second residual connection,
        # as in the layer above.
        ...

        return output
class ParallelRetroEncoderTransformerCALayer(MegatronModule):
    """A single transformer layer for Retro Encoder with cross attention.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()
        super(ParallelRetroEncoderTransformerCALayer, self).__init__()
        # Same sub-modules as the decoder layers above, with the encoder
        # specific dropouts substituted in:
        ...
        self.self_attention.attention_dropout = \
            torch.nn.Dropout(args.retro_encoder_attention_dropout)
        self.hidden_dropout = args.retro_encoder_hidden_dropout
        ...

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Input layernorm, self attention, bias-dropout-add and
        # post-attention layernorm, as above.
        ...
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Neighbors.
        args = get_args()
        retro_args = get_retro_args()
        retrieved_length = retro_args.retro_gpt_retrieved_length
        num_neighbors = args.retro_num_neighbors

        ns, bs, d = layernorm_output.shape  # [r, bs * l * k, d]
        chunked_outputs = layernorm_output.reshape(
            retrieved_length, -1, num_neighbors, d)
        chunked_outputs_before_layer_norm = \
            layernorm_input.reshape(retrieved_length, -1,
                                    num_neighbors, d)  # [r, bs * l, k, d]

        layernorm_inputs = []
        layernorm_outputs = []
        for k in range(num_neighbors):
            chunked_output = chunked_outputs[:, :, k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output,                    # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output)   # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_layer_norm[:, :, k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
                layernorm_inputs.append(layernorm_input)

            # Layer norm post the decoder attention
            layernorm_output = \
                self.post_inter_attention_layernorm(layernorm_input)
            layernorm_outputs.append(layernorm_output)

        # layernorm_input : [r, k * bs * l, d]
        # layernorm_output : [r, k * bs * l, d]
        layernorm_input = \
            torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
        layernorm_output = \
            torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)

        # MLP, second residual connection and bias-dropout-add or DropPath,
        # as above.
        ...

        return output
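A shape sketch (assumed toy sizes, standalone) for the per-neighbor split above: the encoder sees all k neighbors stacked along the batch dimension and cross-attends each neighbor separately against the same chunked decoder hidden states.

import torch

r, bs_l, k, d = 6, 4, 2, 8                  # retrieved length, bs*l, neighbors, hidden
x = torch.randn(r, bs_l * k, d)             # [r, bs*l*k, d], as in the forward above
per_neighbor = x.reshape(r, bs_l, k, d)     # [r, bs*l, k, d]
neighbor_0 = per_neighbor[:, :, 0]          # [r, bs*l, d], fed to inter_attention as Q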
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()
        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Input layernorm, self attention, DropPath and post-attention
        # layernorm, set up exactly as in the Retro layers above.
        ...
        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method, output_layer_init_method, layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [b, s, h]

        # Input layernorm -> self attention -> (fused) bias-dropout-add ->
        # post-attention layernorm -> cross attention for LayerType.decoder ->
        # MLP -> second bias-dropout-add or DropPath, mirroring the standard
        # layer in transformer.py.
        ...

        return output
class ParallelRetroEncoder(MegatronModule):
    """ Retro Transformer class for encoder ."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelRetroEncoder, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpoiting flag.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel
        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = args.retro_encoder_layers
        self.drop_path_rates = [rate.item() for rate in
                                torch.linspace(0, self.drop_path_rate, args.num_layers)]

        if args.retro_add_retriever:
            self.P = [1]

        # Transformer layers.
        assert args.retro_add_retriever
        def build_layer(layer_number):
            if layer_number in self.P:
                return ParallelRetroEncoderTransformerCALayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                layer = ParallelTransformerLayer(
                    init_method, output_layer_init_method, layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
                layer.self_attention.attention_dropout = \
                    torch.nn.Dropout(args.retro_encoder_attention_dropout)
                layer.hidden_dropout = args.retro_encoder_hidden_dropout
                return layer

        # Layer-to-stage assignment: virtual pipeline offsets, the
        # encoder_and_decoder split, and the standalone embedding stage
        # NoopTransformerLayer fallback, as in transformer.py.
        ...
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        # 'uniform' or 'block' activation checkpointing over self.layers,
        # as in transformer.py's ParallelTransformer._checkpointed_forward.
        ...
        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor (see the rationale in transformer.py).
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states, requires_grad=True, keep_graph=True)

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        args = get_args()
        assert not args.sequence_parallel, "if SP, need rng context."

        # Forward pass.
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                if index + 1 in self.P:
                    hidden_states = layer(
                        hidden_states, attention_mask,
                        retriever_output=retriever_output,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                else:
                    hidden_states = layer(
                        hidden_states, attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output
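A worked example (assumed toy config, standalone) of the virtual-pipeline layer offset used when building the layer list: 8 layers, 2 pipeline stages, 4 model chunks.

num_layers, vp_size, pp_size = 8, 4, 2
layers_per_chunk = num_layers // pp_size // vp_size        # self.num_layers == 1
for vp_rank in range(vp_size):
    for pp_rank in range(pp_size):
        offset = vp_rank * (num_layers // vp_size) + pp_rank * layers_per_chunk
        print(vp_rank, pp_rank, [offset + i + 1 for i in range(layers_per_chunk)])
# Stage 0 owns layers [1] [3] [5] [7] and stage 1 owns [2] [4] [6] [8]
# (1-indexed here), matching the 0-indexed "Stage 0: [0] [2] [4] [6] /
# Stage 1: [1] [3] [5] [7]" comment in the source.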
class
ParallelRetroTransformer
(
MegatronModule
):
"""Standard GPT Transformer class."""
def
__init__
(
self
,
init_method
,
output_layer_init_method
,
layer_type
=
LayerType
.
encoder
,
self_attn_mask_type
=
AttnMaskType
.
padding
,
pre_process
=
True
,
post_process
=
True
,
drop_path_rate
=
0.0
,
retriever
=
None
):
super
(
ParallelRetroTransformer
,
self
).
__init__
()
args
=
get_args
()
assert
pre_process
and
post_process
,
\
"pipeline parallelism un-supported."
self
.
bf16
=
args
.
bf16
self
.
fp32_residual_connection
=
args
.
fp32_residual_connection
self
.
pre_process
=
pre_process
self
.
post_process
=
post_process
self
.
input_tensor
=
None
self
.
drop_path_rate
=
drop_path_rate
# Store activation checkpoiting flag.
self
.
recompute_granularity
=
args
.
recompute_granularity
self
.
recompute_method
=
args
.
recompute_method
self
.
recompute_num_layers
=
args
.
recompute_num_layers
self
.
distribute_saved_activations
=
\
args
.
distribute_saved_activations
and
not
args
.
sequence_parallel
self
.
sequence_parallel
=
args
.
sequence_parallel
# Number of layers.
self
.
num_layers
=
_get_num_layers
(
args
,
args
.
model_type
==
ModelType
.
encoder_and_decoder
)
self
.
drop_path_rates
=
[
rate
.
item
()
for
rate
in
torch
.
linspace
(
0
,
self
.
drop_path_rate
,
args
.
num_layers
)]
if
args
.
retro_add_retriever
:
if
args
.
num_layers
==
12
:
self
.
P
=
[
6
,
9
,
12
]
elif
args
.
num_layers
==
24
:
self
.
P
=
np
.
arange
(
9
,
25
,
3
).
tolist
()
elif
args
.
num_layers
==
40
:
self
.
P
=
np
.
arange
(
9
,
41
,
3
).
tolist
()
self
.
P
.
append
(
40
)
self
.
retriever
=
retriever
# Transformer layers.
assert
args
.
retro_add_retriever
def
build_layer
(
layer_number
):
if
layer_number
==
min
(
self
.
P
):
return
ParallelRetroTransformerEncoderLayer
(
init_method
,
output_layer_init_method
,
layer_number
,
layer_type
=
layer_type
,
self_attn_mask_type
=
self_attn_mask_type
,
drop_path_rate
=
self
.
drop_path_rates
[
layer_number
-
1
],
retriever
=
retriever
)
elif
layer_number
in
self
.
P
:
return
ParallelRetroTransformerLayer
(
init_method
,
output_layer_init_method
,
layer_number
,
layer_type
=
layer_type
,
self_attn_mask_type
=
self_attn_mask_type
,
drop_path_rate
=
self
.
drop_path_rates
[
layer_number
-
1
])
else
:
return
ParallelTransformerLayer
(
init_method
,
output_layer_init_method
,
layer_number
,
layer_type
=
layer_type
,
self_attn_mask_type
=
self_attn_mask_type
,
drop_path_rate
=
self
.
drop_path_rates
[
layer_number
-
1
])
if
args
.
virtual_pipeline_model_parallel_size
is
not
None
:
assert
args
.
num_layers
%
args
.
virtual_pipeline_model_parallel_size
==
0
,
\
'num_layers_per_stage must be divisible by '
\
'virtual_pipeline_model_parallel_size'
assert
args
.
model_type
!=
ModelType
.
encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self
.
num_layers
=
self
.
num_layers
//
args
.
virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
                # layers to stages like (each list is a model chunk):
                # Stage 0: [0, 1] [4, 5]
                # Stage 1: [2, 3] [6, 7]
                offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                    args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                    (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
            else:
                # Each stage gets a contiguous set of layers.
                if args.model_type == ModelType.encoder_and_decoder and \
                        parallel_state.get_pipeline_model_parallel_world_size() > 1:
                    pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
                    if layer_type == LayerType.encoder:
                        offset = pipeline_rank * self.num_layers
                    else:
                        num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                        offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
                else:
                    offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing the number of checkpointed layers.
            l = 0
            while l < self.num_layers:
                hidden_states = parallel_state.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory by removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = parallel_state.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                retriever_output=None,
                retriever_attn_mask=None,
                encoder_output=None,
                enc_dec_attn_mask=None,
                inference_params=None):

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        args = get_args()

        # Transpose retriever output, to match hidden_states shape.
        retriever_output = retriever_output.transpose(0, 1).contiguous()

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        assert not args.sequence_parallel, "if SP, need rng context."
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                if args.retro_add_retriever and index + 1 == min(self.P):
                    hidden_states, E = layer(
                        hidden_states,
                        attention_mask,
                        retriever_output=retriever_output,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                elif args.retro_add_retriever and index + 1 in self.P:
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        retriever_output=E,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                else:
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        output = self.final_layernorm(hidden_states)

        return output
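Note: a standalone sketch of the per-rank layer-offset arithmetic used above, assuming the interleaved (virtual) pipeline schedule; the function and variable names here are illustrative and not part of the repository.

# Minimal sketch of the layer offset computed above (assumption: interleaved
# virtual-pipeline schedule; names are illustrative, not Megatron APIs).
def layer_offset(num_layers_total, layers_per_rank, pipeline_rank,
                 virtual_rank=None, virtual_size=None):
    if virtual_size is not None:
        # Each virtual chunk owns num_layers_total // virtual_size layers,
        # and within a chunk each pipeline rank owns layers_per_rank of them.
        return virtual_rank * (num_layers_total // virtual_size) \
            + pipeline_rank * layers_per_rank
    # Without interleaving, each stage simply gets a contiguous slice.
    return pipeline_rank * layers_per_rank

# 8 layers, 2 pipeline ranks, 2 virtual chunks -> stage 0 builds [0, 1] and [4, 5].
assert layer_offset(8, 2, 0, virtual_rank=0, virtual_size=2) == 0
assert layer_offset(8, 2, 0, virtual_rank=1, virtual_size=2) == 4
assert layer_offset(8, 2, 1, virtual_rank=1, virtual_size=2) == 6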
megatron/model/t5_model.py View file @ 051f58f1
...
@@ -11,9 +11,7 @@ from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
-    get_linear_layer,
-    init_method_normal,
-    scaled_init_method_normal
+    get_linear_layer
 )
 from .module import MegatronModule
...
@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule):
     Arguments:
         mpu_vocab_size: model parallel size of vocabulary.
-        hidden_size: hidden size
-        init_method: init method for weight initialization
-        layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: wether output logits being distributed or not.
     """

     def __init__(self, mpu_vocab_size, parallel_output):
         super(T5LMHead, self).__init__()
-        args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0
...
@@ -72,41 +65,38 @@ class T5Model(MegatronModule):
     """T5 Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=0,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True,
                  add_encoder=True,
                  add_decoder=True):
-        super(T5Model, self).__init__()
+        super().__init__(config=config)
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.parallel_output = parallel_output
-        init_method = init_method_normal(args.init_method_std)
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
         self.pre_process = pre_process
         self.post_process = post_process
         self.add_encoder = add_encoder
         self.add_decoder = add_decoder

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             add_encoder=add_encoder,
             add_decoder=add_decoder,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process)

-        self.initialize_word_embeddings(init_method_normal)
+        self.initialize_word_embeddings()

         if self.post_process and self.add_decoder:
             self.lm_head = T5LMHead(
-                self.word_embeddings_weight().size(0),
+                self.shared_embedding_or_output_weight().size(0),
                 parallel_output)
             self._lm_head_key = 'lm_head'
...
@@ -139,7 +129,7 @@ class T5Model(MegatronModule):
         decoder_output, encoder_output = lm_output

         # Output. [s, b, h]
         lm_logits = self.lm_head(decoder_output,
-                                 self.word_embeddings_weight())
+                                 self.shared_embedding_or_output_weight())

         if lm_labels is None:
             # [s b h] => [b s h]
...
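For context on the rename above (word_embeddings_weight to shared_embedding_or_output_weight), a minimal sketch of what the tied LM head computes: logits are the decoder hidden states projected against the shared embedding matrix plus the vocab-size bias that T5LMHead holds. Plain PyTorch, not the Megatron tensor-parallel implementation; the sizes are made up.

import torch

vocab, hidden, seq, batch = 100, 16, 4, 2
embedding_weight = torch.randn(vocab, hidden)      # shared with the input embedding
bias = torch.zeros(vocab)                          # analogue of T5LMHead.bias
decoder_output = torch.randn(seq, batch, hidden)   # [s, b, h]

# [s, b, h] x [v, h] -> [s, b, v]
logits = torch.einsum('sbh,vh->sbv', decoder_output, embedding_weight) + bias
print(logits.shape)  # torch.Size([4, 2, 100])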
megatron/model/transformer.py View file @ 051f58f1

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """Transformer."""
-import math
 from contextlib import nullcontext
+import math
+import numpy as np
 import torch
 import torch.nn.functional as F
 from typing import Optional

-from megatron import get_timers, get_args, core, get_num_microbatches
+from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
...
@@ -15,7 +16,7 @@ from megatron.model import LayerNorm
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
+from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 try:
...
@@ -26,7 +27,10 @@ except ImportError:
 try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
-    flash_attn_unpadded_func = None
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
+    except ImportError:
+        flash_attn_unpadded_func = None

 """ We use the following notation throughout this file:
     h: hidden size
@@ -65,18 +69,6 @@ class DropPath(MegatronModule):
...
@@ -65,18 +69,6 @@ class DropPath(MegatronModule):
output
=
hidden_state
.
div
(
keep_prob
)
*
random_tensor
output
=
hidden_state
.
div
(
keep_prob
)
*
random_tensor
return
output
return
output
def
_args_to_kwargs
():
args
=
get_args
()
common_kwargs
=
{
"params_dtype"
:
args
.
params_dtype
,
"use_cpu_initialization"
:
args
.
use_cpu_initialization
,
"perform_initialization"
:
args
.
perform_initialization
,
"gradient_accumulation_fusion"
:
args
.
gradient_accumulation_fusion
,
"sequence_parallel_enabled"
:
args
.
sequence_parallel
,
}
return
common_kwargs
class
ParallelMLP
(
MegatronModule
):
class
ParallelMLP
(
MegatronModule
):
"""MLP.
"""MLP.
...
@@ -85,22 +77,26 @@ class ParallelMLP(MegatronModule):
     state back into h hidden dimension.
     """

-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(self, config):
         super(ParallelMLP, self).__init__()
         args = get_args()

-        self.add_bias = args.add_bias_linear
+        self.add_bias = config.add_bias_linear
+
+        ffn_hidden_size = config.ffn_hidden_size
+        if config.gated_linear_unit:
+            ffn_hidden_size *= 2

         # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
-            args.hidden_size,
-            args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
+            config.hidden_size,
+            ffn_hidden_size,
+            config=config,
+            init_method=config.init_method,
             bias=self.add_bias,
             gather_output=False,
-            init_method=init_method,
             skip_bias_add=True,
-            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-            **_args_to_kwargs())
+        )

         self.bias_gelu_fusion = False
         self.activation_func = None
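A toy illustration of why the projection width is doubled when a gated linear unit (swiglu) is used: the single matmul produces both a gate half and a value half, and the activated output keeps only ffn_hidden_size features. Plain PyTorch, not the tensor-parallel layer; sizes are illustrative.

import torch
import torch.nn.functional as F

hidden, ffn = 8, 32
x = torch.randn(5, hidden)

# One matmul of width 2*ffn, split into gate and value halves.
w = torch.randn(hidden, 2 * ffn)
gate, value = (x @ w).chunk(2, dim=-1)
y = F.silu(gate) * value            # SwiGLU activation, output width is ffn
print(y.shape)                      # torch.Size([5, 32])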
...
@@ -125,13 +121,13 @@ class ParallelMLP(MegatronModule):

         # Project back to h.
         self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
-            args.ffn_hidden_size,
-            args.hidden_size,
+            config.ffn_hidden_size,
+            config.hidden_size,
+            config=config,
+            init_method=config.output_layer_init_method,
             bias=self.add_bias,
-            input_is_parallel=True,
-            init_method=output_layer_init_method,
-            skip_bias_add=True,
-            **_args_to_kwargs())
+            input_is_parallel=True
+        )

     def forward(self, hidden_states):
...
@@ -155,13 +151,13 @@ class SwitchMLP(MegatronModule):
     """
     Routes input to one of N MLP "experts"
     """
-    def __init__(self, init_method, output_layer_init_method):
+    def __init__(self, config):
         super(SwitchMLP, self).__init__()
         args = get_args()
-        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
+        self.router = torch.nn.Linear(config.hidden_size, args.num_experts)
         self.experts = torch.nn.ModuleList()
         for i in range(args.num_experts):
-            self.experts.append(ParallelMLP(init_method, output_layer_init_method))
+            self.experts.append(ParallelMLP(config))

     def forward(self, hidden_states):
         # hidden_states: [s, b, h]
...
@@ -188,45 +184,48 @@ class SwitchMLP(MegatronModule):
             local_indices = (max_ind == expert_num).nonzero()
             hidden = hidden_states[local_indices,:]
             output, output_bias = expert(hidden)
-            output_bias = output_bias.expand_as(output)
+            if output_bias is not None:
+                output_bias = output_bias.expand_as(output)
+                output_bias_total[local_indices,:] = output_bias
             output_total[local_indices,:] = output
-            output_bias_total[local_indices,:] = output_bias

         output_total = output_total*max_prob
-        output_bias_total = output_bias_total*max_prob
         output_total = output_total.view(s, b, h)
-        output_bias_total = output_bias_total.view(s, b, h)
+        if output_bias is not None:
+            output_bias_total = output_bias_total*max_prob
+            output_bias_total = output_bias_total.view(s, b, h)
+        else:
+            output_bias_total = None

         return output_total, output_bias_total


 class CoreAttention(MegatronModule):

-    def __init__(self, layer_number,
+    def __init__(self, layer_number, config,
                  attn_mask_type=AttnMaskType.padding):
         super(CoreAttention, self).__init__()
-        args = get_args()
-        self.fp16 = args.fp16
-        self.bf16 = args.bf16
+        self.fp16 = config.fp16
+        self.bf16 = config.bf16

-        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
+        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         self.attn_mask_type = attn_mask_type
-        self.sequence_parallel = args.sequence_parallel
+        self.sequence_parallel = config.sequence_parallel

-        projection_size = args.kv_channels * args.num_attention_heads
+        projection_size = config.kv_channels * config.num_attention_heads

         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                            world_size)
         self.hidden_size_per_attention_head = core.utils.divide(
-            projection_size, args.num_attention_heads)
+            projection_size, config.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(
-            args.num_attention_heads, world_size)
+            config.num_attention_heads, world_size)

         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
...
@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16, self.bf16,
             self.attn_mask_type,
-            args.masked_softmax_fusion,
+            config.masked_softmax_fusion,
             attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
...
@@ -245,7 +244,7 @@ class CoreAttention(MegatronModule):
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

     def forward(self, query_layer, key_layer,
                 value_layer, attention_mask):
...
@@ -261,8 +260,8 @@ class CoreAttention(MegatronModule):
                        key_layer.size(0))

         # [sq, b, np, hn] -> [sq, b * np, hn]
-        query_layer = query_layer.view(output_size[2],
-                                       output_size[0] * output_size[1], -1)
+        query_layer = query_layer.reshape(output_size[2],
+                                          output_size[0] * output_size[1], -1)
         # [sk, b, np, hn] -> [sk, b * np, hn]
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)
...
@@ -379,17 +378,18 @@ class FlashSelfAttention(torch.nn.Module):
             is_causal = self.causal
             cu_seqlens_k = cu_seqlens_q
+            dropout_p = self.dropout_p
         else:
             # turn off FA causal mask after first inference autoregressive iteration
             # only on first autoregressive step q,k,v have same seqlen
             is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                         device=q.device)
-            self.dropout_p = 0
+            dropout_p = 0

         output = flash_attn_unpadded_func(
             q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            self.dropout_p,
+            dropout_p,
             softmax_scale=self.softmax_scale, causal=is_causal
         )
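A sketch of the cu_seqlens bookkeeping used above: for fixed-length sequences the cumulative sequence-length vector is just an arange with step seqlen, marking where each packed sequence starts and ends. Plain PyTorch; it does not call flash-attn.

import torch

batch_size, seqlen_k = 3, 5
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                            dtype=torch.int32)
print(cu_seqlens_k)   # tensor([ 0,  5, 10, 15], dtype=torch.int32)
# Sequence i occupies rows cu_seqlens_k[i]:cu_seqlens_k[i+1] of the packed tensor.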
...
@@ -404,8 +404,7 @@ class ParallelAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, init_method,
-                 output_layer_init_method, layer_number,
+    def __init__(self, config, layer_number,
                  attention_type=AttnType.self_attn,
                  attn_mask_type=AttnMaskType.padding):
         super(ParallelAttention, self).__init__()
...
@@ -413,10 +412,21 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
-        self.params_dtype = args.params_dtype
-        self.sequence_parallel = args.sequence_parallel
+        self.params_dtype = config.params_dtype
+        self.sequence_parallel = config.sequence_parallel
+
+        self.group_query_attention = args.group_query_attention
+        self.num_query_groups = args.num_query_groups
+
+        query_projection_size = config.kv_channels * config.num_attention_heads
+        if self.group_query_attention:
+            kv_projection_size = args.kv_channels * args.num_query_groups
+        else:
+            kv_projection_size = args.kv_channels * args.num_attention_heads

         self.use_flash_attn = args.use_flash_attn \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         if self.use_flash_attn:
             if flash_attn_unpadded_func is None:
                 raise ImportError('FlashAttention is not installed, please install with '
...
@@ -428,64 +438,72 @@ class ParallelAttention(MegatronModule):
             if rearrange is None:
                 raise ImportError('einops is not installed, please install with pip install einops')

-        projection_size = args.kv_channels * args.num_attention_heads
-
         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = core.utils.divide(
-            projection_size, args.num_attention_heads)
+            query_projection_size, config.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(
-            args.num_attention_heads, world_size)
+            config.num_attention_heads, world_size)
+
+        if self.group_query_attention:
+            if args.num_query_groups % world_size != 0:
+                raise NotImplementedError('Currently the num_query_groups should be '
+                                          'a multiple of the tensor parallel size')
+            self.num_query_groups_per_partition = core.utils.divide(
+                        args.num_query_groups, world_size)
+        else:
+            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

         # Strided linear layer.
         if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size,
-                3 * projection_size,
-                bias=args.add_bias_linear,
-                gather_output=False,
-                init_method=init_method,
-                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-                **_args_to_kwargs())
+                config.hidden_size,
+                query_projection_size + 2 * kv_projection_size,
+                config=config,
+                init_method=config.init_method,
+                bias=args.add_bias_linear,
+                gather_output=False)
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size,
-                projection_size,
-                bias=args.add_bias_linear,
-                gather_output=False,
-                init_method=init_method,
-                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-                **_args_to_kwargs())
-
-            self.key_value = tensor_parallel.ColumnParallelLinear(
-                args.hidden_size,
-                2 * projection_size,
-                bias=args.add_bias_linear,
-                gather_output=False,
-                init_method=init_method,
-                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
-                **_args_to_kwargs())
-
-        self.core_attention = CoreAttention(self.layer_number,
+
+            if self.group_query_attention:
+                raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
+            assert query_projection_size == kv_projection_size
+
+            self.query = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                query_projection_size,
+                config=config,
+                init_method=config.init_method,
+                bias=config.add_bias_linear,
+                gather_output=False)
+
+            self.key_value = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                2 * kv_projection_size,
+                config=config,
+                init_method=config.init_method,
+                bias=config.add_bias_linear,
+                gather_output=False)
+
+        self.core_attention = CoreAttention(self.layer_number, config,
                                             self.attn_mask_type)
-        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

         if self.use_flash_attn:
             self.core_attention_flash = FlashSelfAttention(
-                causal=True, attention_dropout=args.attention_dropout
+                causal=True, attention_dropout=config.attention_dropout
             )

         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
-            projection_size,
-            args.hidden_size,
-            bias=args.add_bias_linear,
-            input_is_parallel=True,
-            init_method=output_layer_init_method,
-            skip_bias_add=True,
-            **_args_to_kwargs())
+            query_projection_size,
+            config.hidden_size,
+            config=config,
+            init_method=config.output_layer_init_method,
+            bias=args.add_bias_linear,
+            input_is_parallel=True,
+            skip_bias_add=True)

     def _checkpointed_attention_forward(self, query_layer, key_layer,
                                         value_layer, attention_mask,
...
@@ -510,11 +528,11 @@ class ParallelAttention(MegatronModule):
         return hidden_states

-    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+    def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
         return torch.empty(
             inference_max_sequence_len,
             batch_size,
-            self.num_attention_heads_per_partition,
+            num_attention_heads,
             self.hidden_size_per_attention_head,
             dtype=self.params_dtype,
             device=torch.cuda.current_device())
...
@@ -530,12 +548,15 @@ class ParallelAttention(MegatronModule):
         is_first_step = False
         if inference_params:
             if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_seq_len = inference_params.max_sequence_length
                 inf_max_batch_size = inference_params.max_batch_size
                 inference_key_memory = self._allocate_memory(
-                    inf_max_seq_len, inf_max_batch_size)
+                    inf_max_seq_len, inf_max_batch_size,
+                    self.num_query_groups_per_partition)
                 inference_value_memory = self._allocate_memory(
-                    inf_max_seq_len, inf_max_batch_size)
+                    inf_max_seq_len, inf_max_batch_size,
+                    self.num_query_groups_per_partition)
                 inference_params.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
                 is_first_step = True
...
@@ -546,21 +567,36 @@ class ParallelAttention(MegatronModule):
         # =====================
         # Query, Key, and Value
         # =====================

         if self.attention_type == AttnType.self_attn:
-            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)

-            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
-            new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 3 * self.hidden_size_per_attention_head)
+            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+            new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                self.num_query_groups_per_partition,
+                (
+                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
+                    * self.hidden_size_per_attention_head
+                ),
+            )
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
             (query_layer,
              key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = torch.split(
+                 mixed_x_layer,
+                 [
+                     (self.num_attention_heads_per_partition // self.num_query_groups_per_partition
+                      * self.hidden_size_per_attention_head),
+                     self.hidden_size_per_attention_head,
+                     self.hidden_size_per_attention_head
+                 ],
+                 dim=3)
+
+            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
+            query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
...
@@ -568,19 +604,19 @@ class ParallelAttention(MegatronModule):
             # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
             new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                 (self.num_attention_heads_per_partition,
                  2 * self.hidden_size_per_attention_head)
             mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
             (key_layer,
              value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)
             # [sq, b, hp] --> [sq, b, np, hn]
             new_tensor_shape = query_layer.size()[:-1] + \
                 (self.num_attention_heads_per_partition,
                  self.hidden_size_per_attention_head)
             query_layer = query_layer.view(*new_tensor_shape)

         # ==================================
...
@@ -632,11 +668,20 @@ class ParallelAttention(MegatronModule):
                 k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                 rotary_pos_emb = (q_pos_emb, k_pos_emb)

         # ==================================
         # core attention computation
         # ==================================

+        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
+        key_layer = key_layer.repeat_interleave(
+            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+            dim=2)
+        value_layer = value_layer.repeat_interleave(
+            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+            dim=2)
+
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
             q_pos_emb, k_pos_emb = rotary_pos_emb
...
@@ -711,10 +756,11 @@ class ParallelTransformerLayer(MegatronModule):
     output of the same size.
     """

-    def __init__(self, init_method, output_layer_init_method,
+    def __init__(self, config,
                  layer_number, layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
                  drop_path_rate=0.):
+                 # retriever=None):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
...
@@ -722,57 +768,59 @@ class ParallelTransformerLayer(MegatronModule):
         self.layer_type = layer_type
         self.apply_residual_connection_post_layernorm \
-            = args.apply_residual_connection_post_layernorm
+            = config.apply_residual_connection_post_layernorm

-        self.bf16 = args.bf16
-        self.fp32_residual_connection = args.fp32_residual_connection
+        self.bf16 = config.bf16
+        self.fp32_residual_connection = config.fp32_residual_connection

         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
-            args.hidden_size,
-            eps=args.layernorm_epsilon,
+            config.hidden_size,
+            eps=config.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel,
+            sequence_parallel=config.sequence_parallel,
             apply_layernorm_1p=args.apply_layernorm_1p)

         # Self attention.
         self.self_attention = ParallelAttention(
-            init_method,
-            output_layer_init_method,
+            config,
             layer_number,
             attention_type=AttnType.self_attn,
             attn_mask_type=self_attn_mask_type)
-        self.hidden_dropout = args.hidden_dropout
-        self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.hidden_dropout = config.hidden_dropout
+        self.bias_dropout_fusion = config.bias_dropout_fusion
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
-            args.hidden_size,
-            eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel,
+            config.hidden_size,
+            eps=config.layernorm_epsilon,
+            no_persist_layer_norm=not config.persist_layer_norm,
+            sequence_parallel=config.sequence_parallel,
             apply_layernorm_1p=args.apply_layernorm_1p)

-        if self.layer_type == LayerType.decoder:
+        # Cross attention.
+        if self.layer_type in (LayerType.decoder,
+                               LayerType.retro_decoder,
+                               LayerType.retro_decoder_with_retriever,
+                               LayerType.retro_encoder):
             self.inter_attention = ParallelAttention(
-                init_method,
-                output_layer_init_method,
+                config,
                 layer_number,
                 attention_type=AttnType.cross_attn)
             # Layernorm on the attention output.
             self.post_inter_attention_layernorm = LayerNorm(
-                args.hidden_size,
-                eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.sequence_parallel,
+                config.hidden_size,
+                eps=config.layernorm_epsilon,
+                no_persist_layer_norm=not config.persist_layer_norm,
+                sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p)

         # MLP
         if args.num_experts is not None:
-            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+            self.mlp = SwitchMLP(config)
         else:
-            self.mlp = ParallelMLP(init_method, output_layer_init_method)
+            self.mlp = ParallelMLP(config)

         # Set bias+dropout+add fusion grad_enable execution handler.
         TORCH_MAJOR = int(torch.__version__.split('.')[0])
...
@@ -781,13 +829,245 @@ class ParallelTransformerLayer(MegatronModule):
         self.bias_dropout_add_exec_handler = \
                 nullcontext if use_nvfuser else torch.enable_grad

+        if args.retro_add_retriever:
+            retro_args = get_retro_args()
+            self.retro_num_neighbors = args.retro_num_neighbors
+            self.retro_chunk_length = retro_args.retro_gpt_chunk_length
+            self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length
+
+        # Retriever (bi-directional transformer with cross attention)
+        if layer_type == LayerType.retro_decoder_with_retriever:
+            self.retriever = ParallelTransformer(
+                config=config,
+                model_type=ModelType.retro_encoder,
+                self_attn_mask_type=AttnMaskType.padding,
+                pre_process=True,
+                post_process=False,
+            )
+            self._retriever_key = 'retriever'
+        else:
+            self.retriever = None
+
+    def default_decoder_cross_attention(self,
+                                        encoder_output,
+                                        enc_dec_attn_mask,
+                                        layernorm_input,
+                                        layernorm_output,
+                                        bias_dropout_add_func):
+        '''Cross attention for a standard encoder-decoder model.'''
+
+        # Attention.
+        attention_output, attention_bias = \
+            self.inter_attention(layernorm_output,
+                                 enc_dec_attn_mask,
+                                 encoder_output=encoder_output)
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        if attention_bias is not None:
+            attention_bias = attention_bias.expand_as(residual)
+
+        # Bias-dropout-add.
+        with self.bias_dropout_add_exec_handler():
+            layernorm_input = bias_dropout_add_func(
+                attention_output,
+                attention_bias,
+                residual,
+                self.hidden_dropout)
+
+        # Layer norm.
+        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+
+        return layernorm_input, layernorm_output
+
+    def retro_encoder_cross_attention(self,
+                                      retriever_output,
+                                      layernorm_input,
+                                      layernorm_output,
+                                      bias_dropout_add_func):
+        """Cross attention for Retro encoder.
+
+        Notation:
+            ns : Sequence length.
+            bs : Batch size.
+            d  : Hidden size.
+            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
+            k  : Number of neighbors.
+            r  : Number of retrieved tokens (neighbors + continuation).
+        """
+
+        ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]
+
+        # Divide sequence dimension into chunks.
+        chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length,
+                                                   -1,
+                                                   self.retro_num_neighbors,
+                                                   d)
+        chunked_outputs_before_layer_norm = \
+            layernorm_input.reshape(self.retro_retrieved_length, -1,
+                                    self.retro_num_neighbors, d) # [r, bs*l, k, d]
+
+        # Per-chunk attention.
+        layernorm_inputs = []
+        layernorm_outputs = []
+        for k in range(self.retro_num_neighbors):
+
+            # Attention.
+            chunked_output = chunked_outputs[:,:,k].contiguous()
+            attention_output, attention_bias = \
+                self.inter_attention(
+                    chunked_output, # Q (neighbor embedding)
+                    None,
+                    encoder_output=retriever_output) # K, V (hidden act)
+
+            # Residual connection.
+            if self.apply_residual_connection_post_layernorm:
+                residual = chunked_output
+            else:
+                residual = chunked_outputs_before_layer_norm[:,:,k]
+
+            # Re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    None if attention_bias is None else attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+                layernorm_inputs.append(layernorm_input)
+
+            # Layer norm.
+            layernorm_output = \
+                self.post_inter_attention_layernorm(layernorm_input)
+            layernorm_outputs.append(layernorm_output)
+
+        # Concatenate layer norms.
+        # layernorm_input : [r, k * bs * l, d]
+        # layernorm_output : [r, k * bs * l, d]
+        layernorm_input = \
+            torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
+        layernorm_output = \
+            torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)
+
+        return layernorm_input, layernorm_output
+
+    def retro_decoder_cross_attention(self,
+                                      retriever_input,
+                                      retriever_output,
+                                      retriever_attn_mask,
+                                      layernorm_input,
+                                      layernorm_output,
+                                      inference_params,
+                                      bias_dropout_add_func):
+        """Cross attention for Retro decoder.
+
+        Notation:
+            ns : Sequence length.
+            bs : Batch size.
+            d  : Hidden size.
+            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
+            m  : Number of tokens per chunk.
+            k  : Number of neighbors.
+            r  : Number of retrieved tokens (neighbors + continuation).
+        """
+
+        ns, bs, d = layernorm_output.shape
+        l = int(np.ceil(ns / self.retro_chunk_length))
+
+        # Retrieve neighbors.
+        if self.layer_type == LayerType.retro_decoder_with_retriever:
+            first_ns = ns % self.retro_chunk_length
+            if first_ns > 0:
+                raise Exception("test this case.")
+                first_chunk, rest_chunk = \
+                    layernorm_output[:first_ns], layernorm_output[first_ns:]
+                first_chunk = torch.nn.functional.pad(
+                    first_chunk,
+                    (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
+                    'constant',
+                    0)
+                chunked_output = \
+                    torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
+            else:
+                chunked_output = layernorm_output # [l * m, bs, d]
+            chunked_output = chunked_output \
+                .reshape(l, self.retro_chunk_length, bs, d) \
+                .permute(1, 2, 0, 3) \
+                .reshape(self.retro_chunk_length, bs * l, d) \
+                .contiguous()
+
+            # Get Encoder Output
+            retriever_output = self.retriever(
+                hidden_states=retriever_input,
+                attention_mask=retriever_attn_mask,
+                retriever_output=chunked_output,
+                retriever_attn_mask=retriever_attn_mask,
+                inference_params=inference_params) # [r, k * bs * l , d]
+            retriever_output = retriever_output.reshape(
+                self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]
+
+        # Chunks.
+        pad = (ns - 1) % self.retro_chunk_length
+        attending_chunks = layernorm_output[pad:]
+        padded_chunks = torch.nn.functional.pad(
+            attending_chunks,
+            (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
+            'constant', 0)
+        padded_chunked_output = padded_chunks \
+            .reshape(l, self.retro_chunk_length, bs, d) \
+            .permute(1, 2, 0, 3)
+        padded_chunked_output = padded_chunked_output.reshape(
+            self.retro_chunk_length, bs * l, d).contiguous()
+
+        # Encoder output.
+        attention_output, attention_bias = \
+            self.inter_attention(padded_chunked_output,
+                                 None,
+                                 encoder_output=retriever_output)
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        # Re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            layernorm_input = bias_dropout_add_func(
+                attention_output,
+                None if attention_bias is None else attention_bias.expand_as(attention_output),
+                torch.zeros_like(attention_output),
+                self.hidden_dropout)
+            layernorm_input = layernorm_input \
+                .reshape(self.retro_chunk_length, bs, l, d) \
+                .permute(2, 0, 1, 3) # [l, m, bs, d]
+            layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d)
+            layernorm_input = torch.nn.functional.pad(
+                layernorm_input,
+                (0, 0, 0, 0, pad, 0),
+                'constant', 0)[:ns] # [ns, b, d]
+            layernorm_input = layernorm_input + residual
+
+        # Layer norm post the decoder attention
+        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+
+        return retriever_output, layernorm_input, layernorm_output
+
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
+                retriever_input=None,
+                retriever_output=None,
+                retriever_attn_mask=None,
                 inference_params=None,
                 rotary_pos_emb=None):
         # hidden_states: [s, b, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(
...
@@ -832,29 +1112,38 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

-        if self.layer_type == LayerType.decoder:
-            attention_output, attention_bias = \
-                self.inter_attention(layernorm_output,
-                                     enc_dec_attn_mask,
-                                     encoder_output=encoder_output)
-            # residual connection
-            if self.apply_residual_connection_post_layernorm:
-                residual = layernorm_output
-            else:
-                residual = layernorm_input
-
-            if attention_bias is not None:
-                attention_bias = attention_bias.expand_as(residual)
-
-            with self.bias_dropout_add_exec_handler():
-                layernorm_input = bias_dropout_add_func(
-                    attention_output,
-                    attention_bias,
-                    residual,
-                    self.hidden_dropout)
-
-            # Layer norm post the decoder attention
-            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
+        # Cross attention.
+        if self.layer_type == LayerType.encoder:
+            pass
+        elif self.layer_type == LayerType.decoder:
+            layernorm_input, layernorm_output = \
+                self.default_decoder_cross_attention(
+                    encoder_output,
+                    enc_dec_attn_mask,
+                    layernorm_input,
+                    layernorm_output,
+                    bias_dropout_add_func)
+        elif self.layer_type == LayerType.retro_encoder:
+            layernorm_input, layernorm_output = \
+                self.retro_encoder_cross_attention(
+                    retriever_output,
+                    layernorm_input,
+                    layernorm_output,
+                    bias_dropout_add_func)
+        elif self.layer_type in (LayerType.retro_decoder,
+                                 LayerType.retro_decoder_with_retriever):
+            retriever_output, layernorm_input, layernorm_output = \
+                self.retro_decoder_cross_attention(
+                    retriever_input,
+                    retriever_output,
+                    retriever_attn_mask,
+                    layernorm_input,
+                    layernorm_output,
+                    inference_params,
+                    bias_dropout_add_func)
+        else:
+            raise Exception("Unsupported layer type, '%s'." %
+                            self.layer_type.name)

         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)
...
@@ -893,7 +1182,10 @@ class ParallelTransformerLayer(MegatronModule):
                                                  training=self.training)
             output = residual + self.drop_path(out)

-        return output
+        if self.layer_type == LayerType.retro_decoder_with_retriever:
+            return output, retriever_output
+        else:
+            return output


 class NoopTransformerLayer(MegatronModule):
...
@@ -922,9 +1214,12 @@ class NoopTransformerLayer(MegatronModule):
         return hidden_states.clone()


-def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
+def _get_num_layers(args, model_type, is_decoder=False):
     """Compute the number of transformer layers resident on the current rank."""
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
+    if model_type == ModelType.retro_encoder:
+        num_layers = args.retro_encoder_layers
+    elif mpu.get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
             assert args.pipeline_model_parallel_split_rank is not None
...
@@ -974,51 +1269,91 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
...
@@ -974,51 +1269,91 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
return
num_layers
return
num_layers
def
_get_layer_type
(
model_type
,
default_layer_type
,
retro_layer_numbers
,
layer_number
):
args
=
get_args
()
if
args
.
retro_add_retriever
and
layer_number
in
retro_layer_numbers
:
if
model_type
==
ModelType
.
retro_decoder
:
return
LayerType
.
retro_decoder_with_retriever
\
if
layer_number
==
retro_layer_numbers
[
0
]
\
else
LayerType
.
retro_decoder
elif
model_type
==
ModelType
.
retro_encoder
:
return
LayerType
.
retro_encoder
else
:
raise
Exception
(
"Unsupported model type, '%s'."
%
model_type
)
else
:
return
default_layer_type
class
ParallelTransformer
(
MegatronModule
):
class
ParallelTransformer
(
MegatronModule
):
"""Transformer class."""
"""Transformer class."""
def
__init__
(
self
,
init_method
,
output_layer_init_method
,
def
__init__
(
self
,
config
,
layer_type
=
LayerType
.
encoder
,
model_type
,
layer_type
=
LayerType
.
encoder
,
self_attn_mask_type
=
AttnMaskType
.
padding
,
self_attn_mask_type
=
AttnMaskType
.
padding
,
post_layer_norm
=
True
,
post_layer_norm
=
True
,
pre_process
=
True
,
post_process
=
True
,
pre_process
=
True
,
post_process
=
True
,
drop_path_rate
=
0.0
):
drop_path_rate
=
0.0
):
super
(
ParallelTransformer
,
self
).
__init__
()
super
(
ParallelTransformer
,
self
).
__init__
()
args
=
get_args
()
args
=
get_args
()
self
.
layer_type
=
layer_type
self
.
layer_type
=
layer_type
self
.
model_type
=
args
.
model_type
self
.
model_type
=
model_type
self
.
bf16
=
args
.
bf16
self
.
bf16
=
config
.
bf16
self
.
fp32_residual_connection
=
args
.
fp32_residual_connection
self
.
fp32_residual_connection
=
config
.
fp32_residual_connection
self
.
post_layer_norm
=
post_layer_norm
self
.
post_layer_norm
=
post_layer_norm
self
.
pre_process
=
pre_process
self
.
pre_process
=
pre_process
self
.
post_process
=
post_process
self
.
post_process
=
post_process
self
.
input_tensor
=
None
self
.
input_tensor
=
None
self
.
drop_path_rate
=
drop_path_rate
self
.
drop_path_rate
=
drop_path_rate
self
.
transformer_impl
=
args
.
transformer_impl
self
.
transformer_impl
=
args
.
transformer_impl
self
.
retro_add_retriever
=
args
.
retro_add_retriever
# Store activation checkpoiting flag.
# Store activation checkpoiting flag.
self
.
recompute_granularity
=
args
.
recompute_granularity
self
.
recompute_granularity
=
config
.
recompute_granularity
self
.
recompute_method
=
args
.
recompute_method
self
.
recompute_method
=
config
.
recompute_method
self
.
recompute_num_layers
=
args
.
recompute_num_layers
self
.
recompute_num_layers
=
config
.
recompute_num_layers
self
.
distribute_saved_activations
=
\
self
.
distribute_saved_activations
=
\
args
.
distribute_saved_activations
and
not
args
.
sequence_parallel
config
.
distribute_saved_activations
and
not
config
.
sequence_parallel
self
.
sequence_parallel
=
args
.
sequence_parallel
self
.
sequence_parallel
=
config
.
sequence_parallel
# Transformer Engine Init.
# Transformer Engine Init.
self
.
transformer_engine_v_0_10
=
False
self
.
transformer_engine_v_0_11
=
False
self
.
transformer_engine_v_0_8
=
False
if
self
.
transformer_impl
==
'transformer_engine'
:
if
self
.
transformer_impl
==
'transformer_engine'
:
global
transformer_engine
global
transformer_engine
import
transformer_engine
import
transformer_engine
self
.
use_fp8
=
args
.
fp8_e4m3
or
args
.
fp8_hybrid
from
importlib.metadata
import
version
from
pkg_resources
import
packaging
te_version
=
packaging
.
version
.
Version
(
version
(
"transformer-engine"
))
if
te_version
>=
packaging
.
version
.
Version
(
"0.8.0"
):
self
.
transformer_engine_v_0_8
=
True
if
te_version
>=
packaging
.
version
.
Version
(
"0.10.0"
):
self
.
transformer_engine_v_0_10
=
True
if
te_version
>=
packaging
.
version
.
Version
(
"0.11.0"
):
self
.
transformer_engine_v_0_11
=
True
del
version
,
packaging
assert
not
args
.
squared_relu
,
"TransformerEngine does not support squared relu activation."
self
.
use_fp8
=
args
.
fp8
is
not
None
self
.
fp8_recipe
=
None
self
.
fp8_recipe
=
None
self
.
fp8_group
=
None
self
.
fp8_group
=
None
if
self
.
use_fp8
:
if
self
.
use_fp8
:
self
.
fp8_group
=
mpu
.
get_data_parallel_group
()
assert
args
.
transformer_impl
==
'transformer_engine'
,
\
if
args
.
fp8_e4m3
:
'transformer-engine required for fp8 training and inference'
self
.
fp8_group
=
mpu
.
get_amax_reduction_group
()
if
args
.
fp8
==
"e4m3"
:
fp8_format
=
transformer_engine
.
common
.
recipe
.
Format
.
E4M3
fp8_format
=
transformer_engine
.
common
.
recipe
.
Format
.
E4M3
elif
args
.
fp8
_
hybrid
:
elif
args
.
fp8
==
"
hybrid
"
:
fp8_format
=
transformer_engine
.
common
.
recipe
.
Format
.
HYBRID
fp8_format
=
transformer_engine
.
common
.
recipe
.
Format
.
HYBRID
else
:
raise
ValueError
(
"The DelayedScaling recipe only supports E4M3 and HYBRID formats."
)
self
.
fp8_recipe
=
transformer_engine
.
common
.
recipe
.
DelayedScaling
(
self
.
fp8_recipe
=
transformer_engine
.
common
.
recipe
.
DelayedScaling
(
margin
=
args
.
fp8_margin
,
margin
=
args
.
fp8_margin
,
interval
=
args
.
fp8_interval
,
interval
=
args
.
fp8_interval
,
...
@@ -1030,63 +1365,87 @@ class ParallelTransformer(MegatronModule):
         self.num_microbatches_in_previous_step = -1
         self.microbatch_count = 0
-        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

         # Number of layers.
         self.num_layers = _get_num_layers(
-            args,
-            args.model_type == ModelType.encoder_and_decoder,
+            args, model_type,
             layer_type == LayerType.decoder)

-        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
+        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, config.num_layers)]

+        self.retro_layer_numbers = None
+        if model_type == ModelType.retro_decoder:
+            retro_layer_start = 6 if config.num_layers <= 15 else 9
+            self.retro_layer_numbers = \
+                np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
+        if model_type == ModelType.retro_encoder:
+            self.retro_layer_numbers = [1]

         # Transformer layers.
+        if args.retro_add_retriever:
+            assert self.recompute_granularity != 'full', \
+                "Full recompute not supported for Retro."
+            assert args.transformer_impl == 'local', \
+                "Transformer engine does not support Retro layers."
         def build_layer(layer_number):
             if args.transformer_impl == 'local':
+                current_layer_type = _get_layer_type(
+                    model_type, layer_type, self.retro_layer_numbers,
+                    layer_number)
                 return ParallelTransformerLayer(
-                    init_method,
-                    output_layer_init_method,
+                    config,
                     layer_number,
-                    layer_type=layer_type,
+                    layer_type=current_layer_type,
                     self_attn_mask_type=self_attn_mask_type,
                     drop_path_rate=self.drop_path_rates[layer_number - 1])
             else:
+                # This argument is only available from TE v0.10 onwards.
+                extra_transformer_engine_kwargs = {}
+                if self.transformer_engine_v_0_8:
+                    extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
+                if self.transformer_engine_v_0_10:
+                    extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
+                if self.transformer_engine_v_0_11:
+                    extra_transformer_engine_kwargs["normalization"] = args.normalization
                 return transformer_engine.pytorch.TransformerLayer(
-                    args.hidden_size,
-                    args.ffn_hidden_size,
-                    args.num_attention_heads,
-                    layernorm_epsilon=args.layernorm_epsilon,
-                    hidden_dropout=args.hidden_dropout,
-                    attention_dropout=args.attention_dropout,
-                    init_method=init_method,
-                    output_layer_init_method=output_layer_init_method,
+                    config.hidden_size,
+                    config.ffn_hidden_size,
+                    config.num_attention_heads,
+                    layernorm_epsilon=config.layernorm_epsilon,
+                    hidden_dropout=config.hidden_dropout,
+                    attention_dropout=config.attention_dropout,
+                    init_method=config.init_method,
+                    output_layer_init_method=config.output_layer_init_method,
                     layer_number=layer_number,
-                    kv_channels=args.kv_channels,
+                    kv_channels=config.kv_channels,
                     self_attn_mask_type=self_attn_mask_type.name,
                     tp_group=mpu.get_tensor_model_parallel_group(),
                     get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
-                    fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
-                    apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-                    attention_softmax_in_fp32=args.attention_softmax_in_fp32,
+                    fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
+                    apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
+                    attention_softmax_in_fp32=config.attention_softmax_in_fp32,
                     seq_length=args.seq_length,
                     micro_batch_size=args.micro_batch_size,
-                    sequence_parallel=args.sequence_parallel,
-                    params_dtype=args.params_dtype,
-                    apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
+                    sequence_parallel=config.sequence_parallel,
+                    params_dtype=config.params_dtype,
+                    apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                     output_layernorm=False,
                     layer_type="encoder",
                     drop_path_rate=self.drop_path_rates[layer_number - 1],
                     set_parallel_mode=True,
-                    fuse_qkv_params=True)
+                    fuse_qkv_params=True,
+                    **extra_transformer_engine_kwargs)

-        if args.virtual_pipeline_model_parallel_size is not None:
-            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+        if config.virtual_pipeline_model_parallel_size is not None:
+            assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
                 'virtual_pipeline_model_parallel_size'
             assert args.model_type != ModelType.encoder_and_decoder
             # Number of layers in each model chunk is the number of layers in the stage,
             # divided by the number of model chunks in a stage.
-            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
             # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
             # layers to stages like (each list is a model chunk):
             # Stage 0: [0] [2] [4] [6]
...
@@ -1096,7 +1455,7 @@ class ParallelTransformer(MegatronModule):
             # Stage 0: [0, 1] [4, 5]
             # Stage 1: [2, 3] [6, 7]
             offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
-                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                config.num_layers // config.virtual_pipeline_model_parallel_size) + \
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
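To make the interleaved assignment in the hunk above concrete, here is a small, self-contained sketch that reproduces the offset arithmetic for the 8-layer examples in the comments; the function name and loop are illustrative, not Megatron code.

    # Hypothetical helper reproducing the interleaved-offset arithmetic above.
    def chunk_layers(total_layers, pipeline_stages, virtual_chunks):
        layers_per_chunk = total_layers // virtual_chunks // pipeline_stages
        assignment = {}
        for stage in range(pipeline_stages):
            for vchunk in range(virtual_chunks):
                offset = vchunk * (total_layers // virtual_chunks) + stage * layers_per_chunk
                assignment[(stage, vchunk)] = list(range(offset, offset + layers_per_chunk))
        return assignment

    # 8 layers, 2 stages, 4 chunks -> stage 0 holds [0] [2] [4] [6], stage 1 holds [1] [3] [5] [7].
    # 8 layers, 2 stages, 2 chunks -> stage 0 holds [0, 1] [4, 5], stage 1 holds [2, 3] [6, 7].
    print(chunk_layers(8, 2, 4))
    print(chunk_layers(8, 2, 2))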
...
@@ -1126,13 +1485,24 @@ class ParallelTransformer(MegatronModule):
             self.layers = torch.nn.ModuleList(
                 [build_layer(i + 1 + offset) for i in range(self.num_layers)])

+            # Update dropout rate for Retro encoder.
+            if model_type == ModelType.retro_encoder:
+                for layer in self.layers:
+                    if layer.self_attention.use_flash_attn:
+                        layer.self_attention.core_attention_flash.dropout_p = \
+                            torch.nn.Dropout(args.retro_encoder_attention_dropout)
+                    else:
+                        layer.self_attention.core_attention.attention_dropout.p = \
+                            args.retro_encoder_attention_dropout
+                    layer.hidden_dropout = args.retro_encoder_hidden_dropout
+
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
-                args.hidden_size,
-                eps=args.layernorm_epsilon,
+                config.hidden_size,
+                eps=config.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.sequence_parallel,
+                sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p)

     def _get_layer(self, layer_number):
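As a side note, the Retro-encoder block added above relies on PyTorch dropout probabilities being mutable after construction. A minimal, stand-alone illustration of that patching pattern (toy module names, not Megatron's) could look like:

    import torch.nn as nn

    # Toy module list standing in for transformer layers (names are illustrative).
    layers = nn.ModuleList(
        [nn.ModuleDict({"attention_dropout": nn.Dropout(0.1)}) for _ in range(4)])

    # Patch the dropout probability in place after construction, as the
    # Retro-encoder branch above does for its attention and hidden dropouts.
    for layer in layers:
        layer["attention_dropout"].p = 0.0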
...
@@ -1142,40 +1512,42 @@ class ParallelTransformer(MegatronModule):
                               encoder_output, enc_dec_attn_mask,
                               rotary_pos_emb, is_first_microbatch):
         """Forward method with activation checkpointing."""
-        def custom(start, end, is_transformer_engine=False):
+        def custom(start, end):
             def custom_forward(*args, **kwargs):
                 x_, *args = args
                 for index in range(start, end):
                     layer = self._get_layer(index)
                     x_ = layer(x_, *args, **kwargs)
                 return x_
-            def custom_forward_transformer_engine(*args, **kwargs):
-                return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
-            if not is_transformer_engine:
-                return custom_forward
-            else:
-                return custom_forward_transformer_engine
+            return custom_forward
+
+        te_forward_kwargs = {}
+        if self.transformer_impl == 'transformer_engine':
+            te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
+            if self.transformer_engine_v_0_10:
+                te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

         if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and
             # checkpoint the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 if self.transformer_impl == 'transformer_engine':
-                    hidden_states = transformer_engine.pytorch.distributed.checkpoint(
-                        custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
+                    hidden_states = transformer_engine.pytorch.checkpoint(
+                        custom(l, l + self.recompute_num_layers),
                         self.distribute_saved_activations,
                         tensor_parallel.get_cuda_rng_tracker,
                         mpu.get_tensor_model_parallel_group(),
                         hidden_states, attention_mask, encoder_output,
-                        enc_dec_attn_mask, rotary_pos_emb)
+                        enc_dec_attn_mask, **te_forward_kwargs)
                 else:
                     hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + self.recompute_num_layers),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask,
-                        encoder_output, enc_dec_attn_mask, rotary_pos_emb)
+                        encoder_output, enc_dec_attn_mask,
+                        None, None, None, None, rotary_pos_emb)

                 l += self.recompute_num_layers
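For readers unfamiliar with the `custom(start, end)` closure above, the same idea can be sketched with plain `torch.utils.checkpoint`. This is an illustration of the pattern only, not Megatron's `tensor_parallel.checkpoint` or TransformerEngine's checkpoint API.

    import torch
    from torch.utils.checkpoint import checkpoint

    def make_chunk_forward(layers, start, end):
        """Run layers[start:end] as one recomputable chunk."""
        def chunk_forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return chunk_forward

    layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(8)])
    x = torch.randn(4, 16, requires_grad=True)
    chunk_size = 2
    l = 0
    while l < len(layers):
        # Activations inside each chunk are recomputed during backward instead of stored.
        x = checkpoint(make_chunk_forward(layers, l, l + chunk_size), x, use_reentrant=False)
        l += chunk_size
    x.sum().backward()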
...
@@ -1186,28 +1558,30 @@ class ParallelTransformer(MegatronModule):
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
                     if self.transformer_impl == 'transformer_engine':
-                        hidden_states = transformer_engine.pytorch.distributed.checkpoint(
-                            custom(l, l + 1, is_transformer_engine=True),
+                        hidden_states = transformer_engine.pytorch.checkpoint(
+                            custom(l, l + 1),
                             self.distribute_saved_activations,
                             tensor_parallel.get_cuda_rng_tracker,
                             mpu.get_tensor_model_parallel_group(),
                             hidden_states, attention_mask, encoder_output,
-                            enc_dec_attn_mask, rotary_pos_emb)
+                            enc_dec_attn_mask, **te_forward_kwargs)
                     else:
                         hidden_states = tensor_parallel.checkpoint(
                             custom(l, l + 1),
                             self.distribute_saved_activations,
                             hidden_states, attention_mask,
-                            encoder_output, enc_dec_attn_mask, rotary_pos_emb)
+                            encoder_output, enc_dec_attn_mask,
+                            None, None, None, None, rotary_pos_emb)
                 else:
                     if self.transformer_impl == 'transformer_engine':
-                        hidden_states = custom(l, l + 1, is_transformer_engine=True)(
+                        hidden_states = custom(l, l + 1)(
                             hidden_states, attention_mask, encoder_output,
-                            enc_dec_attn_mask, rotary_pos_emb)
+                            enc_dec_attn_mask, **te_forward_kwargs)
                     else:
                         hidden_states = custom(l, l + 1)(
                             hidden_states, attention_mask,
-                            encoder_output, enc_dec_attn_mask, rotary_pos_emb)
+                            encoder_output, enc_dec_attn_mask,
+                            None, None, None, None, rotary_pos_emb)
         else:
             raise ValueError("Invalid activation recompute method.")
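The two recompute schedules in this method differ only in which layers get wrapped in a checkpoint. A small arithmetic sketch (hypothetical helper, not Megatron code) makes the difference explicit:

    def checkpointed_layers(num_layers, recompute_num_layers, method):
        """Return which layer indices are recomputed under each schedule."""
        if method == 'uniform':
            # Every layer is checkpointed, in chunks of recompute_num_layers.
            return [list(range(l, min(l + recompute_num_layers, num_layers)))
                    for l in range(0, num_layers, recompute_num_layers)]
        if method == 'block':
            # Only the first recompute_num_layers layers are checkpointed, one by one.
            return [[l] for l in range(min(recompute_num_layers, num_layers))]
        raise ValueError("Invalid activation recompute method.")

    print(checkpointed_layers(8, 2, 'uniform'))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(checkpointed_layers(8, 2, 'block'))    # [[0], [1]]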
...
@@ -1225,7 +1599,11 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
+                retriever_input=None,
+                retriever_output=None,
+                retriever_attn_mask=None,
                 inference_params=None,
                 rotary_pos_emb=None):
         # hidden_states: [s, b, h]

         # Checks.
...
@@ -1258,11 +1636,13 @@ class ParallelTransformer(MegatronModule):
                     keep_graph=True,
                 )

+        # RNG context.
         if self.sequence_parallel:
             rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
         else:
             rng_context = nullcontext()

+        # Forward layers.
         with rng_context:
             # The fp8_autocast context manager is a no-op when enabled=True
             # The if...else serves to short circuit name resolution for fp8_autocast
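The fork-or-nullcontext selection above is a common way to make a context manager optional. A self-contained sketch of the same idea, using only the standard library (the flag and lock are illustrative):

    from contextlib import nullcontext
    import threading

    lock = threading.Lock()
    sequence_parallel = False  # illustrative flag

    # Enter a real context only when needed, mirroring the rng_context selection above.
    ctx = lock if sequence_parallel else nullcontext()
    with ctx:
        pass  # body runs with or without the extra context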
...
@@ -1290,12 +1670,18 @@ class ParallelTransformer(MegatronModule):
                 forward_kwargs = {
                     'encoder_output': encoder_output,
                     'enc_dec_attn_mask': enc_dec_attn_mask,
                     'inference_params': inference_params,
-                    'rotary_pos_emb': rotary_pos_emb,
                 }

                 if self.transformer_impl == 'transformer_engine':
                     forward_kwargs['is_first_microbatch'] = is_first_microbatch
                     forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+                    if self.transformer_engine_v_0_10:
+                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+                else:
+                    forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+                    forward_kwargs['retriever_input'] = retriever_input
+                    forward_kwargs['retriever_output'] = retriever_output
+                    forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

                 for index in range(self.num_layers):
                     layer = self._get_layer(index)
...
@@ -1305,6 +1691,14 @@ class ParallelTransformer(MegatronModule):
                         attention_mask,
                         **forward_kwargs)

+                    # First Retro decoder layer returns both hidden_states
+                    # and retriever_output. Make retriever_output available
+                    # to subsequence Retro layers.
+                    if isinstance(hidden_states, tuple):
+                        assert len(hidden_states) == 2
+                        hidden_states, retriever_output = hidden_states
+                        forward_kwargs["retriever_output"] = retriever_output

         # Skip counter update for eval and activation checkpointing
         if torch.is_grad_enabled() and self.training:
             self.microbatch_count += 1
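A compact illustration of the tuple-unpacking convention described in the new comment, with a toy layer in place of a Retro decoder layer (the names below are not Megatron's):

    def toy_layer(x, aux=None):
        # The first call produces an auxiliary output; later calls just consume it.
        return (x + 1, "retrieved") if aux is None else x + 1

    x, aux = 0, None
    for _ in range(3):
        out = toy_layer(x, aux=aux)
        if isinstance(out, tuple):
            x, aux = out  # make the auxiliary output available to later layers
        else:
            x = out
    print(x, aux)  # 3 retrieved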
...
megatron/model/vision/classification.py View file @ 051f58f1
...
@@ -13,7 +13,7 @@ from megatron.model.module import MegatronModule
 class VitClassificationModel(MegatronModule):
     """Vision Transformer Model."""

-    def __init__(self, num_classes, finetune=False,
+    def __init__(self, config, num_classes, finetune=False,
                  pre_process=True, post_process=True):
         super(VitClassificationModel, self).__init__()
         args = get_args()
...
@@ -24,6 +24,7 @@ class VitClassificationModel(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.backbone = VitBackbone(
+            config=config,
             pre_process=self.pre_process,
             post_process=self.post_process,
             single_token_output=True
...
megatron/model/vision/dino.py View file @ 051f58f1
...
@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
     return schedule

-def get_student_backbone_and_num_features(pre_process=True, post_process=True):
+def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
     args = get_args()

     if args.vision_backbone_type == 'vit':
-        student = VitBackbone(pre_process=pre_process,
+        student = VitBackbone(config,
+                              pre_process=pre_process,
                               post_process=post_process,
                               drop_path_rate=0.1,
                               single_token_output=True)
...
@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):
     return student, num_features

-def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
+def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
     args = get_args()

     if args.vision_backbone_type == 'vit':
-        teacher = VitBackbone(pre_process=pre_process,
+        teacher = VitBackbone(config,
+                              pre_process=pre_process,
                               post_process=post_process,
                               single_token_output=True)
         num_features = args.hidden_size
...
@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
 class DINOPretrainModel(MegatronModule):
-    def __init__(self, pre_process=True, post_process=True):
+    def __init__(self, config, pre_process=True, post_process=True):
         super(DINOPretrainModel, self).__init__()
         args = get_args()
         self.out_dim = 65536
...
@@ -234,7 +236,7 @@ class DINOPretrainModel(MegatronModule):
         self.momentum_teacher = 0.996

         student_backbone, num_features = \
-            get_student_backbone_and_num_features(pre_process, post_process)
+            get_student_backbone_and_num_features(config, pre_process, post_process)

         self.student = MultiCropWrapper(
             student_backbone,
...
@@ -249,7 +251,7 @@ class DINOPretrainModel(MegatronModule):
         )

         teacher_backbone, num_features = \
-            get_teacher_backbone_and_num_features(pre_process, post_process)
+            get_teacher_backbone_and_num_features(config, pre_process, post_process)
         self.teacher = MultiCropWrapper(
             teacher_backbone,
             DINOHead(num_features, self.out_dim)
...
megatron/model/vision/inpainting.py View file @ 051f58f1
...
@@ -18,14 +18,15 @@ from megatron.model.vision.utils import resize_
 class VitInpaintingModel(MegatronModule):

-    def __init__(self, pre_process=True, post_process=True):
+    def __init__(self, config, pre_process=True, post_process=True):
         super(VitInpaintingModel, self).__init__()
         args = get_args()

         self.pre_process = pre_process
         self.post_process = post_process
-        self.hidden_size = args.hidden_size
+        self.hidden_size = config.hidden_size
         self.backbone = VitBackbone(
+            config=config,
             pre_process=self.pre_process,
             post_process=self.post_process,
             class_token=False,
...
megatron/model/vision/mit_backbone.py View file @ 051f58f1
 # ---------------------------------------------------------------
-# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# Copyright (c) 2023, NVIDIA Corporation. All rights reserved.
 #
 # This work is licensed under the NVIDIA Source Code License
 # found in the LICENSE file in the root directory of this
 # source tree.
 # ---------------------------------------------------------------
 import math
 import torch
 import torch.nn as nn
...
megatron/model/vision/vit_backbone.py View file @ 051f58f1
...
@@ -130,24 +130,17 @@ class VitBackbone(MegatronModule):
     """Vision Transformer Model."""

     def __init__(self,
+                 config,
                  pre_process=True,
                  post_process=True,
                  class_token=True,
                  single_token_output=False,
                  post_layer_norm=True,
                  drop_path_rate=0.0):
-        super(VitBackbone, self).__init__(share_word_embeddings=False)
+        super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
-        if args.init_method_xavier_uniform:
-            self.init_method = torch.nn.init.xavier_uniform_
-            self.scaled_init_method = torch.nn.init.xavier_uniform_
-        else:
-            self.init_method = init_method_normal(args.init_method_std)
-            self.scaled_init_method = scaled_init_method_normal(
-                args.init_method_std, args.num_layers)
         self.pre_process = pre_process
         self.post_process = post_process
...
@@ -202,8 +195,7 @@ class VitBackbone(MegatronModule):
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method,
-            self.scaled_init_method,
+            config,
             pre_process=self.pre_process,
             post_process=self.post_process,
             post_layer_norm=self.post_layer_norm,
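The common thread across these vision-model changes is that construction parameters now arrive through an explicit config object instead of the global args plus locally rebuilt init methods. A generic sketch of that refactoring pattern, with illustrative names that are not Megatron's API:

    from dataclasses import dataclass, field
    import torch

    @dataclass
    class TransformerConfig:  # hypothetical stand-in for Megatron's transformer config
        hidden_size: int
        num_layers: int
        init_method: callable = field(default=torch.nn.init.xavier_uniform_)

    class Backbone(torch.nn.Module):
        def __init__(self, config: TransformerConfig, pre_process: bool = True):
            super().__init__()
            # Read hyperparameters from the config instead of a global args singleton.
            self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)
            config.init_method(self.proj.weight)

    model = Backbone(TransformerConfig(hidden_size=64, num_layers=2))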
...
megatron/optimizer/clip_grads.py View file @ 051f58f1
...
@@ -3,7 +3,7 @@
 """Gradient clipping."""

 import torch
-from torch._six import inf
+from torch import inf

 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
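The private `torch._six` shim has been removed in recent PyTorch releases, which is why this module now imports `inf` from `torch` directly. Code that must run on both old and new PyTorch sometimes uses a guarded import like the sketch below; this is an illustration, not part of the commit.

    # Prefer the public spelling; fall back to the old private shim if present.
    try:
        from torch import inf
    except ImportError:
        from torch._six import inf  # older PyTorch exposed inf via torch._six

    print(float("1e9999") == inf)  # True: both spellings name the same IEEE infinity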