gaoqiong / flash-attention / Commits / 26f4b5fb

Commit 26f4b5fb, authored Jul 31, 2024 by Woosuk Kwon
Merge branch 'main' into Dao-AILab/main
Parents: 5018ac6a, 12375706
Pipeline #2015: failed (0 seconds)
Changes: 95 · Pipelines: 1
Showing 20 changed files with 0 additions and 6409 deletions (+0, -6409).
flash_attn/models/btlm.py                  +0  -102
flash_attn/models/falcon.py                +0  -143
flash_attn/models/gpt.py                   +0  -1080
flash_attn/models/gpt_neox.py              +0  -124
flash_attn/models/gptj.py                  +0  -109
flash_attn/models/llama.py                 +0  -422
flash_attn/models/opt.py                   +0  -116
flash_attn/models/vit.py                   +0  -373
flash_attn/modules/__init__.py             +0  -0
flash_attn/modules/block.py                +0  -397
flash_attn/modules/embedding.py            +0  -216
flash_attn/modules/mha.py                  +0  -1020
flash_attn/modules/mlp.py                  +0  -191
flash_attn/ops/__init__.py                 +0  -0
flash_attn/ops/activations.py              +0  -135
flash_attn/ops/fused_dense.py              +0  -688
flash_attn/ops/layer_norm.py               +0  -800
flash_attn/ops/rms_norm.py                 +0  -174
flash_attn/ops/triton/__init__.py          +0  -1
flash_attn/ops/triton/cross_entropy.py     +0  -318
flash_attn/models/btlm.py  (deleted, file mode 100644 → 0)

# Copyright (c) 2023, Tri Dao.

import math
import json
import re
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig


def remap_state_dict_hf_btlm(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    if "transformer.wpe.weight" in state_dict:
        state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight")
        W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0)
        b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias")
        b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias")
        state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0)
        W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
    state_dict.pop(f"transformer.relative_pe.slopes")  # We don't store the Alibi slopes

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=btlm_config.vocab_size,
        n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions,
        n_embd=btlm_config.hidden_size,
        n_layer=btlm_config.num_hidden_layers,
        n_head=btlm_config.num_attention_heads,
        n_inner=btlm_config.n_inner,
        activation_function=btlm_config.activation_function,
        resid_pdrop=btlm_config.resid_pdrop,
        embd_pdrop=btlm_config.embd_pdrop,
        attn_pdrop=btlm_config.attn_pdrop,
        layer_norm_epsilon=btlm_config.layer_norm_epsilon,
        initializer_range=btlm_config.initializer_range,
        bos_token_id=btlm_config.bos_token_id,
        eos_token_id=btlm_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        use_alibi=btlm_config.position_embedding_type == "alibi",
        # Alibi code path requires flash_attn
        use_flash_attn=btlm_config.position_embedding_type == "alibi",
        mup_width_scale=btlm_config.mup_width_scale,
        mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
        mup_output_multiplier=btlm_config.mup_output_alpha,
        mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
        mlp_multiple_of=1,
    )
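For context, the two helpers above were meant to be used together: convert the Hugging Face BTLM config to a GPT2Config, remap the checkpoint keys, and load the result into this repo's GPTLMHeadModel. The sketch below is illustrative only and is not part of the diff; the checkpoint id "cerebras/btlm-3b-8k-base" and the exact call sequence are assumptions based on the function signatures above.

# Illustrative sketch only, not part of the diff. The checkpoint id and call order are assumptions.
import torch
from transformers import AutoConfig

from flash_attn.models.btlm import btlm_config_to_gpt2_config, remap_state_dict_hf_btlm
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "cerebras/btlm-3b-8k-base"  # assumed HF checkpoint id
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config = btlm_config_to_gpt2_config(hf_config)  # translate the BTLM config into a GPT2Config
state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
model.load_state_dict(state_dict)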
flash_attn/models/falcon.py  (deleted, file mode 100644 → 0)

# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import FalconConfig, GPT2Config


def remap_state_dict_hf_falcon(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", 1)
    headdim = config.hidden_size // n_head
    for l in range(config.n_layer):
        # The weights are stored in a different layout compared to our implementation
        Wqkv = rearrange(
            state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
            "(group ratio headdim) ... -> group ratio headdim ...",
            ratio=n_head // n_head_kv + 2,
            headdim=headdim,
        )
        Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
        Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
        Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
    return state_dict


def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
    # The 40b config uses "n_head_kv" instead of "num_kv_heads"
    n_head_kv = getattr(
        falcon_config,
        "n_head_kv",
        1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
    )
    # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
    # So we have to infer it from the number of heads in the key/value block
    parallel_block_tied_norm = n_head_kv == 1
    return GPT2Config(
        vocab_size=falcon_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=falcon_config.hidden_size,
        n_layer=falcon_config.n_layer,
        n_head=falcon_config.n_head,
        n_inner=falcon_config.hidden_size * 4,
        activation_function="gelu",
        resid_pdrop=falcon_config.hidden_dropout,
        embd_pdrop=0.0,  # There doesn't seem to be any embedding dropout
        attn_pdrop=falcon_config.attention_dropout,
        layer_norm_epsilon=falcon_config.layer_norm_epsilon,
        initializer_range=falcon_config.initializer_range,
        bos_token_id=falcon_config.bos_token_id,
        eos_token_id=falcon_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        parallel_block=falcon_config.parallel_attn,
        n_head_kv=n_head_kv,
        parallel_block_tied_norm=parallel_block_tied_norm,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=False,
        tie_word_embeddings=True,
        qkv_proj_bias=falcon_config.bias,
        out_proj_bias=falcon_config.bias,
        mlp_fc1_bias=falcon_config.bias,
        mlp_fc2_bias=falcon_config.bias,
        lm_head_bias=False,
    )
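The only non-trivial step in remap_state_dict_hf_falcon is the Wqkv re-layout in the loop above: Falcon stores the fused QKV projection grouped per KV head, while the flash-attn MHA module expects all Q rows, then all K rows, then all V rows. The toy check below is illustrative only and not part of the diff; the dimensions (n_head=4, n_head_kv=2, headdim=3, hidden=12) are made up.

# Illustrative sketch only, not part of the diff; all sizes below are made up.
import torch
from einops import rearrange

n_head, n_head_kv, headdim, hidden = 4, 2, 3, 12
ratio = n_head // n_head_kv + 2  # per KV group: (ratio - 2) query heads, 1 key head, 1 value head
w = torch.randn(n_head_kv * ratio * headdim, hidden)  # HF Falcon layout: (group ratio headdim, hidden)
w = rearrange(w, "(group ratio headdim) d -> group ratio headdim d", ratio=ratio, headdim=headdim)
wq = rearrange(w[:, :-2], "group ratio headdim d -> (group ratio headdim) d")
wk = rearrange(w[:, [-2]], "group ratio headdim d -> (group ratio headdim) d")
wv = rearrange(w[:, [-1]], "group ratio headdim d -> (group ratio headdim) d")
wqkv = torch.cat([wq, wk, wv], dim=0)  # flash-attn layout: all Q, then all K, then all V
assert wqkv.shape == (n_head * headdim + 2 * n_head_kv * headdim, hidden)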
flash_attn/models/gpt.py  (deleted, file mode 100644 → 0)

# Copyright (c) 2024, Tri Dao.

import logging
import math
import re
from collections import OrderedDict, namedtuple
from collections.abc import Sequence
from functools import partial
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.llama import remap_state_dict_hf_llama
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (
    FusedMLP,
    GatedMlp,
    Mlp,
    ParallelFusedMLP,
    ParallelGatedMlp,
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather,
    all_gather_raw,
    get_dim_for_local_rank,
    sync_shared_params,
)
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None

logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_alibi = getattr(config, "use_alibi", False)
    window_size = getattr(config, "window_size", (-1, -1))
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_alibi=use_alibi,
        window_size=window_size,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                multiple_of=mlp_multiple_of,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh"
                    if config.activation_function
                    in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            activation = (
                "gelu_approx"
                if config.activation_function
                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block
class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
            "togethercomputer/GPT-JT-"
        ):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif (
            model_name.startswith("EleutherAI/gpt-neox-")
            or model_name.startswith("EleutherAI/pythia-")
            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
        ):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        elif model_name.startswith("meta-llama/Llama-"):
            state_dict = remap_state_dict_hf_llama(state_dict, config)
        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
):
    mup_init_scale = math.sqrt(mup_width_scale)
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
        optim_cfg = getattr(module.weight, "_optim", {})
        optim_cfg.update({"lr_multiplier": mup_width_scale})
        setattr(module.weight, "_optim", optim_cfg)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(
                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
                )
class GPTModel(GPTPreTrainedModel):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
        if rotary_emb_fraction > 0.0:
            # Tie all the RotaryEmbedding modules to share the same cos/sin cache
            for layer in self.layers[1:]:
                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs)
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.embeddings_multiplier != 1.0:
            hidden_states = hidden_states * self.embeddings_multiplier
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None if not self.parallel_block else hidden_states2,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
                )
        return hidden_states
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=mup_width_scale,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)
    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])
            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[n_head + n_head_kv + beg_n_head_kv : n_head + n_head_kv + end_n_head_kv],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim)
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict
def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
                ]
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
                wk = torch.cat(
                    [
                        x[n_head_each_rank[rank] : n_head_each_rank[rank] + n_head_kv_each_rank[rank]]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat(
                    [wq, wk, wv],
                    dim=0,
                )
                state_dict[key] = rearrange(
                    wqkv,
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1)
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict
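shard_state_dict_tp and combine_state_dicts_tp above are intended to be inverses of each other. The CPU-only sketch below checks that round trip on a tiny model; it is illustrative only and not part of the diff, and the config values are made up, chosen merely to satisfy the divisibility asserts in shard_state_dict_tp.

# Illustrative sketch only, not part of the diff; the tiny config is made up.
import torch
from transformers import GPT2Config

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp

config = GPT2Config(vocab_size=64, n_positions=32, n_embd=16, n_layer=2, n_head=4, n_inner=32)
full = GPTLMHeadModel(config).state_dict()
world_size = 2
# shard_state_dict_tp mutates its input, so give each rank its own shallow copy
shards = [shard_state_dict_tp(dict(full), config, world_size, rank) for rank in range(world_size)]
merged = combine_state_dicts_tp(shards, config)
assert all(torch.equal(merged[k], full[k]) for k in full)  # sharding then combining is lossless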
def remap_state_dict_hf_gpt2(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias", None)  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )
    return state_dict
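For context, GPTPreTrainedModel.from_pretrained defined earlier in this file dispatches to these remap_* helpers based on the checkpoint name prefix. A minimal sketch of loading Hugging Face gpt2 weights that way is shown below; it is illustrative only and not part of the diff, and the particular config flags set here (use_flash_attn, residual_in_fp32) are assumptions drawn from the getattr defaults read in create_mixer_cls and create_block above.

# Illustrative sketch only, not part of the diff; the flag combination is an assumption.
import torch
from transformers import GPT2Config

from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
config.use_flash_attn = True     # route attention through FlashAttention (requires the CUDA extension)
config.fused_bias_fc = False     # keep plain nn.Linear so the fused_dense extension is not required
config.residual_in_fp32 = True   # optional: keep the residual stream in fp32

model = GPTLMHeadModel.from_pretrained("gpt2", config, device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
logits = model(input_ids).logits  # CausalLMOutput(logits=...) as returned by forward() above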
flash_attn/models/gpt_neox.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2023, Tri Dao.
import
math
import
re
from
collections
import
OrderedDict
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
GPT2Config
,
GPTNeoXConfig
def
remap_state_dict_hf_gpt_neox
(
state_dict
,
config
):
def
key_mapping_layers
(
key
):
return
re
.
sub
(
r
"^gpt_neox."
,
"transformer."
,
key
)
state_dict
=
OrderedDict
((
key_mapping_layers
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# Word embedding
def
key_mapping_emb
(
key
):
return
re
.
sub
(
r
"^transformer.embed_in."
,
"transformer.embeddings.word_embeddings."
,
key
)
state_dict
=
OrderedDict
((
key_mapping_emb
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
word_embeddings
=
state_dict
.
pop
(
"transformer.embeddings.word_embeddings.weight"
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple
=
getattr
(
config
,
"pad_vocab_size_multiple"
,
1
)
vocab_size
=
math
.
ceil
(
config
.
vocab_size
/
pad_vocab_size_multiple
)
*
pad_vocab_size_multiple
state_dict
[
"transformer.embeddings.word_embeddings.weight"
]
=
F
.
pad
(
word_embeddings
,
(
0
,
0
,
0
,
vocab_size
-
word_embeddings
.
shape
[
0
])
)
if
getattr
(
config
,
"tie_word_embeddings"
,
False
):
state_dict
[
"lm_head.weight"
]
=
state_dict
[
"transformer.embeddings.word_embeddings.weight"
]
else
:
output_embeddings
=
state_dict
.
pop
(
"embed_out.weight"
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict
[
"lm_head.weight"
]
=
F
.
pad
(
output_embeddings
,
(
0
,
0
,
0
,
vocab_size
-
output_embeddings
.
shape
[
0
])
)
# LayerNorm
def
key_mapping_ln
(
key
):
key
=
re
.
sub
(
r
"^transformer.final_layer_norm."
,
r
"transformer.ln_f."
,
key
)
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).input_layernorm."
,
r
"transformer.layers.\1.norm1."
,
key
)
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).post_attention_layernorm."
,
r
"transformer.layers.\1.norm2."
,
key
,
)
return
key
state_dict
=
OrderedDict
((
key_mapping_ln
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# MLP
def
key_mapping_mlp
(
key
):
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).mlp.dense_h_to_4h."
,
r
"transformer.layers.\1.mlp.fc1."
,
key
)
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).mlp.dense_4h_to_h."
,
r
"transformer.layers.\1.mlp.fc2."
,
key
)
return
key
state_dict
=
OrderedDict
((
key_mapping_mlp
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# Attention
for
l
in
range
(
config
.
n_layer
):
# We don't store these biases
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.attention.bias"
)
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.attention.masked_bias"
)
# We don't store these
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.attention.rotary_emb.inv_freq"
,
None
)
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim
=
config
.
hidden_size
//
config
.
num_attention_heads
Wqkv
=
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.attention.query_key_value.weight"
)
state_dict
[
f
"transformer.layers.
{
l
}
.mixer.Wqkv.weight"
]
=
rearrange
(
Wqkv
,
"(nheads three headdim) ... -> (three nheads headdim) ..."
,
three
=
3
,
headdim
=
headdim
,
)
bqkv
=
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.attention.query_key_value.bias"
)
state_dict
[
f
"transformer.layers.
{
l
}
.mixer.Wqkv.bias"
]
=
rearrange
(
bqkv
,
"(nheads three headdim) -> (three nheads headdim)"
,
three
=
3
,
headdim
=
headdim
)
def
key_mapping_attn
(
key
):
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).attention.dense."
,
r
"transformer.layers.\1.mixer.out_proj."
,
key
,
)
return
key
state_dict
=
OrderedDict
((
key_mapping_attn
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
return
state_dict
def
gpt_neox_config_to_gpt2_config
(
gpt_neox_config
:
GPTNeoXConfig
)
->
GPT2Config
:
assert
gpt_neox_config
.
rotary_emb_base
==
10000
return
GPT2Config
(
vocab_size
=
gpt_neox_config
.
vocab_size
,
n_positions
=
0
,
# No absolute position embedding
n_embd
=
gpt_neox_config
.
hidden_size
,
n_layer
=
gpt_neox_config
.
num_hidden_layers
,
n_head
=
gpt_neox_config
.
num_attention_heads
,
n_inner
=
gpt_neox_config
.
intermediate_size
,
activation_function
=
gpt_neox_config
.
hidden_act
,
resid_pdrop
=
0.0
,
# No dropout
embd_pdrop
=
0.0
,
attn_pdrop
=
0.0
,
layer_norm_epsilon
=
gpt_neox_config
.
layer_norm_eps
,
initializer_range
=
gpt_neox_config
.
initializer_range
,
bos_token_id
=
gpt_neox_config
.
bos_token_id
,
eos_token_id
=
gpt_neox_config
.
eos_token_id
,
# These are new arguments not in the original GPT2Config
prenorm
=
True
,
parallel_block
=
gpt_neox_config
.
use_parallel_residual
,
parallel_block_tied_norm
=
False
,
rotary_emb_fraction
=
gpt_neox_config
.
rotary_pct
,
tie_word_embeddings
=
gpt_neox_config
.
tie_word_embeddings
,
)
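A minimal usage sketch of the two GPT-NeoX helpers above (not part of the deleted file): the name of the remap function is an assumption, since this diff only shows its tail, and the "EleutherAI/pythia-160m" checkpoint name is purely illustrative.

from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

hf_config = GPTNeoXConfig.from_pretrained("EleutherAI/pythia-160m")  # illustrative checkpoint
config = gpt_neox_config_to_gpt2_config(hf_config)
hf_state_dict = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m").state_dict()
# Assumed name for the remap function whose body appears above.
state_dict = remap_state_dict_hf_gpt_neox(hf_state_dict, config)
# Keys now follow the flash-attn GPT naming, e.g. "transformer.layers.0.mixer.Wqkv.weight".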
flash_attn/models/gptj.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attn.bias")
        state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    headdim = gptj_config.n_embd // gptj_config.n_head
    return GPT2Config(
        vocab_size=gptj_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gptj_config.n_embd,
        n_layer=gptj_config.n_layer,
        n_head=gptj_config.n_head,
        n_inner=gptj_config.n_inner,
        activation_function=gptj_config.activation_function,
        resid_pdrop=gptj_config.resid_pdrop,
        embd_pdrop=gptj_config.embd_pdrop,
        attn_pdrop=gptj_config.attn_pdrop,
        layer_norm_epsilon=gptj_config.layer_norm_epsilon,
        initializer_range=gptj_config.initializer_range,
        bos_token_id=gptj_config.bos_token_id,
        eos_token_id=gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=True,
        parallel_block_tied_norm=True,
        rotary_emb_fraction=gptj_config.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        lm_head_bias=True,
    )
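A usage sketch for the GPT-J helpers above (not part of the deleted file); the "EleutherAI/gpt-j-6b" checkpoint name is only an illustrative assumption.

from transformers import GPTJConfig, GPTJForCausalLM

gptj_config = GPTJConfig.from_pretrained("EleutherAI/gpt-j-6b")  # illustrative checkpoint
config = gptj_config_to_gpt2_config(gptj_config)
state_dict = remap_state_dict_hf_gptj(
    GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6b").state_dict(), config
)
# Per-layer q/k/v projections are now fused into "transformer.layers.{i}.mixer.Wqkv.weight".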
flash_attn/models/llama.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig
from einops import rearrange


def remap_state_dict_meta_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Meta format to standard GPT format.

    This function modifies state_dict in place.
    """

    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    state_dict.pop("transformer.rope.freqs", None)

    return state_dict


def remap_state_dict_hf_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Hugging Face format to standard GPT format.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^model.layers.(\d+).mlp.down_proj.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^model.layers.(\d+).input_layernorm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return rearrange(
            w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def inv_remap_state_dict_hf_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in standard GPT format to Hugging Face format.

    This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
    multiplier pad in the embedding and lm_head. That is if the original embedding
    isn't a multiple of pad_vocab_size_multiple, then
    inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("model.embed_tokens.weight")
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["model.embed_tokens.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        w3, w1 = torch.chunk(
            state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
        )
        state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
        state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.",
            r"model.layers.\1.mlp.down_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm1.",
            r"model.layers.\1.input_layernorm.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).norm2.",
            r"model.layers.\1.post_attention_layernorm.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def permute(w):
        return rearrange(
            w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
        )

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    q_dim = n_head * head_dim
    k_dim = v_dim = n_head_kv * head_dim

    # Attention
    for l in range(config.n_layer):
        Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
        Wq = Wqkv[:q_dim]
        Wk = Wqkv[q_dim : q_dim + k_dim]
        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
        state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
        state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.",
            r"model.layers.\1.self_attn.o_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=params.get("n_kv_heads", None),
    )
    multiple_of = params.get("multiple_of", 1)
    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)

    # Compute the hidden dimension of the MLP
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
    intermediate_size = 4 * config.hidden_size
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
    intermediate_size = int(2 * intermediate_size / 3)
    # custom dim factor multiplier
    if ffn_dim_multiplier is not None:
        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)

    config.intermediate_size = intermediate_size
    if "rope_theta" in params:
        config.rotary_emb_base = params["rope_theta"]
    config.vocab_size = 32000
    # some CodeLLaMa have vocab_size 32000, some 32016
    # Sadly it's not specified in the `params.json` file :(
    tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
    if tokenizer.is_file():
        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
    return config


def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")


def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    if checkpoint_format == "meta":
        return config_from_meta_checkpoint(checkpoint_path, model_name)
    else:
        return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> List[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Idk if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
        rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
        n_head_kv=llama_config.num_key_value_heads,
    )
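A usage sketch tying the LLaMA helpers above together (not part of the deleted file). The checkpoint path and model name are placeholders, and only the single-shard case is shown; merging tensor-parallel shards is handled elsewhere in the library.

checkpoint_path, model_name = "/path/to/llama-checkpoints", "7B"  # placeholder paths
llama_config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(llama_config)
shards = state_dicts_from_checkpoint(checkpoint_path, model_name)
state_dict = remap_state_dict_meta_llama(shards[0], config)  # assumes a single-shard checkpoint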
flash_attn/models/opt.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig


def remap_state_dict_hf_opt(state_dict, config):
    def key_mapping_model(key):
        key = re.sub(r"^model.decoder.", "transformer.", key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r"^decoder.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
        # The OPT-350m model has project_in and project_out
        key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
        key = re.sub(r"^transformer.project_out.", "project_out.", key)
        key = re.sub(
            r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
    state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
        bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
        bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
        bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).self_attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    word_embed_proj_dim = (
        None
        if opt_config.word_embed_proj_dim == opt_config.hidden_size
        else opt_config.word_embed_proj_dim
    )
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
        n_embd=opt_config.hidden_size,
        n_layer=opt_config.num_hidden_layers,
        n_head=opt_config.num_attention_heads,
        n_inner=opt_config.ffn_dim,
        activation_function=opt_config.activation_function,
        resid_pdrop=opt_config.dropout,
        # HF's implementation of OPT doesn't seem to have embedding dropout
        embd_pdrop=opt_config.dropout,
        attn_pdrop=opt_config.attention_dropout,
        initializer_range=opt_config.init_std,
        bos_token_id=opt_config.bos_token_id,
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
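A usage sketch for the OPT helpers above (not part of the deleted file); the "facebook/opt-125m" checkpoint name is only illustrative.

from transformers import OPTConfig, OPTForCausalLM

opt_config = OPTConfig.from_pretrained("facebook/opt-125m")  # illustrative checkpoint
config = opt_config_to_gpt2_config(opt_config)
state_dict = remap_state_dict_hf_opt(
    OPTForCausalLM.from_pretrained("facebook/opt-125m").state_dict(), config
)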
flash_attn/models/vit.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from collections import OrderedDict
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None


def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False):
    mixer_cls = partial(
        MHA,
        num_heads=num_heads,
        cross_attn=cross_attn,
        qkv_proj_bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_mlp:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
    return mlp_cls


def create_block(
    embed_dim,
    num_heads,
    mlp_ratio,
    qkv_bias,
    drop_rate,
    attn_drop_rate,
    drop_path1,
    drop_path2,
    norm_layer,
    act_layer,
    use_flash_attn,
    fused_bias_fc,
    fused_mlp,
    fused_dropout_add_ln,
    layer_idx=None,
    n_layer=None,
    last_layer_subset=False,
):
    mixer_cls = create_mixer_cls(
        num_heads,
        qkv_bias,
        attn_drop_rate,
        use_flash_attn,
        fused_bias_fc,
        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
    )
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    block = Block(
        embed_dim,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
    return block


class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        use_flash_attn=False,
        fused_bias_fc=False,
        fused_mlp=False,
        fused_dropout_add_ln=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == "token", "Only support pooling with CLS token"
        assert class_token
        assert init_values is None, "LayerScale is not supported yet"
        assert weight_init == ""
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = (
            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
        )
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias,
                    drop_rate,
                    attn_drop_rate,
                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
                    drop_path2=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_flash_attn=use_flash_attn,
                    fused_bias_fc=fused_bias_fc,
                    fused_mlp=fused_mlp,
                    fused_dropout_add_ln=fused_dropout_add_ln,
                    layer_idx=i,
                    n_layer=depth,
                    last_layer_subset=(global_pool == "token"),
                )
                for i in range(depth)
            ]
        )

        self.dropout = nn.Dropout(p=drop_rate)
        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
        self.norm = norm_layer(embed_dim)

        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=""):
        assert mode == ""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return x

    def forward_features(self, x, all_tokens=True):
        """
        If all_tokens==False and self.global_pool == 'token', we only return the features for the
        cls token.
        """
        x = self.patch_embed(x)
        hidden_states = self._pos_embed(x)
        residual = None
        if self.global_pool != "token" or all_tokens:
            # if True:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
            for block in self.blocks[:-1]:
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states, residual = self.blocks[-1](
                hidden_states, residual, mixer_subset=slice(0, 1)
            )
        if not self.fused_dropout_add_ln:
            residual = self.drop_path(self.dropout(hidden_states)) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
        else:
            if self.drop_path.p == 0 or not self.training:
                rowscale = None
            else:
                rowscale = self.drop_path(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                eps=self.norm.eps,
                dropout_p=self.dropout.p if self.training else 0.0,
                rowscale=rowscale,
                prenorm=False,
            )
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x, all_tokens=False)
        x = self.forward_head(x)
        return x

    def load_state_dict(self, state_dict, strict=True):
        patch_embed_weight = state_dict["patch_embed.proj.weight"]
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict["patch_embed.proj.weight"] = rearrange(
                patch_embed_weight, "o c h w -> o (c h w)"
            )

        def key_mapping_attn(key):
            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
            return key

        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (
            self.blocks[-1].mixer.cross_attn
            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
        ):
            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
        return super().load_state_dict(state_dict, strict=strict)


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model
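A quick sketch of instantiating the ViT-B/16 factory above (not part of the deleted file); the fused/flash flags are left off so the sketch does not depend on the optional CUDA/Triton extensions.

import torch

model = vit_base_patch16_224(use_flash_attn=False, fused_bias_fc=False, fused_mlp=False)
images = torch.randn(2, 3, 224, 224)
logits = model(images)  # (2, 1000)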
flash_attn/modules/__init__.py
deleted
100644 → 0
View file @
5018ac6a
flash_attn/modules/block.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reasons: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
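A small sketch of the prenorm calling convention documented in Block.forward (not part of the deleted file): each block returns (hidden_states, residual), and the residual is threaded into the next block.

import torch

dim = 256
blocks = torch.nn.ModuleList([Block(dim, prenorm=True) for _ in range(2)])  # default MHA/Mlp sub-layers
hidden_states, residual = torch.randn(1, 16, dim), None
for block in blocks:
    hidden_states, residual = block(hidden_states, residual)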
flash_attn/modules/embedding.py
deleted
100644 → 0
View file @
5018ac6a
# Copyright (c) 2022, Tri Dao.
import
torch
import
torch.nn
as
nn
from
einops
import
rearrange
from
torch
import
Tensor
from
flash_attn.utils.distributed
import
all_reduce
,
reduce_scatter
class
GPT2Embeddings
(
nn
.
Module
):
def
__init__
(
self
,
embed_dim
,
vocab_size
,
max_position_embeddings
,
padding_idx
=
None
,
word_embed_proj_dim
=
None
,
device
=
None
,
dtype
=
None
,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
the project up to embed_dim
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
if
word_embed_proj_dim
is
None
:
self
.
word_embeddings
=
nn
.
Embedding
(
vocab_size
,
embed_dim
,
padding_idx
=
padding_idx
,
**
factory_kwargs
)
self
.
project_in
=
None
else
:
self
.
word_embeddings
=
nn
.
Embedding
(
vocab_size
,
word_embed_proj_dim
,
padding_idx
=
padding_idx
,
**
factory_kwargs
)
self
.
project_in
=
nn
.
Linear
(
word_embed_proj_dim
,
embed_dim
,
bias
=
False
,
**
factory_kwargs
)
self
.
max_position_embeddings
=
max_position_embeddings
if
self
.
max_position_embeddings
>
0
:
self
.
position_embeddings
=
nn
.
Embedding
(
max_position_embeddings
,
embed_dim
,
**
factory_kwargs
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size
,
seqlen
=
input_ids
.
shape
embeddings
=
self
.
word_embeddings
(
input_ids
)
if
self
.
project_in
is
not
None
:
embeddings
=
self
.
project_in
(
embeddings
)
if
self
.
max_position_embeddings
>
0
:
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
seqlen
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
embeddings
=
embeddings
+
position_embeddings
return
embeddings
class
BertEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
embed_dim
,
vocab_size
,
max_position_embeddings
,
type_vocab_size
,
padding_idx
=
None
,
device
=
None
,
dtype
=
None
,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If type_vocab_size <= 0, there's no token type embeddings
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
word_embeddings
=
nn
.
Embedding
(
vocab_size
,
embed_dim
,
padding_idx
=
padding_idx
,
**
factory_kwargs
)
self
.
max_position_embeddings
=
max_position_embeddings
self
.
type_vocab_size
=
type_vocab_size
if
self
.
max_position_embeddings
>
0
:
self
.
position_embeddings
=
nn
.
Embedding
(
max_position_embeddings
,
embed_dim
,
**
factory_kwargs
)
if
self
.
type_vocab_size
>
0
:
self
.
token_type_embeddings
=
nn
.
Embedding
(
type_vocab_size
,
embed_dim
,
**
factory_kwargs
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
token_type_ids: (batch, seqlen)
"""
batch_size
,
seqlen
=
input_ids
.
shape
embeddings
=
self
.
word_embeddings
(
input_ids
)
if
self
.
max_position_embeddings
>
0
:
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
seqlen
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
embeddings
=
embeddings
+
position_embeddings
if
self
.
type_vocab_size
>
0
:
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
seqlen
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
embeddings
+
token_type_embeddings
return
embeddings
class
VocabParallelEmbedding
(
nn
.
Embedding
):
def
__init__
(
self
,
num_embeddings
,
*
args
,
process_group
=
None
,
padding_idx
=
None
,
**
kwargs
):
self
.
process_group
=
process_group
if
process_group
is
not
None
:
world_size
=
torch
.
distributed
.
get_world_size
(
process_group
)
if
num_embeddings
%
world_size
!=
0
:
raise
ValueError
(
f
"num_embeddings (
{
num_embeddings
}
) must be divisible by "
f
"world_size (
{
world_size
}
)"
)
if
world_size
>
1
and
padding_idx
is
not
None
:
raise
RuntimeError
(
"ParallelEmbedding does not support padding_idx"
)
else
:
world_size
=
1
super
().
__init__
(
num_embeddings
//
world_size
,
*
args
,
padding_idx
=
padding_idx
,
**
kwargs
)
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
if
self
.
process_group
is
None
:
return
super
().
forward
(
input
)
else
:
rank
=
torch
.
distributed
.
get_rank
(
self
.
process_group
)
vocab_size
=
self
.
num_embeddings
vocab_start_index
,
vocab_end_index
=
rank
*
vocab_size
,
(
rank
+
1
)
*
vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
input_ids_mask
=
(
input
<
vocab_start_index
)
|
(
input
>=
vocab_end_index
)
input
=
input
-
vocab_start_index
input
[
input_ids_mask
]
=
0
embeddings
=
super
().
forward
(
input
)
embeddings
[
input_ids_mask
]
=
0.0
return
embeddings
class
ColumnParallelEmbedding
(
nn
.
Embedding
):
def
__init__
(
self
,
num_embeddings
,
embedding_dim
,
*
args
,
process_group
=
None
,
**
kwargs
):
self
.
process_group
=
process_group
if
process_group
is
not
None
:
world_size
=
torch
.
distributed
.
get_world_size
(
process_group
)
if
embedding_dim
%
world_size
!=
0
:
raise
ValueError
(
f
"embedding_dim (
{
embedding_dim
}
) must be divisible by "
f
"world_size (
{
world_size
}
)"
)
else
:
world_size
=
1
super
().
__init__
(
num_embeddings
,
embedding_dim
//
world_size
,
*
args
,
**
kwargs
)
class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[..., rank * partition_dim : (rank + 1) * partition_dim] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
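The position table above is column-parallel, so each rank only holds a slice of the hidden dimension and adds it in place before the reduction; the later all-reduce / reduce-scatter then sums the ranks, so every slice ends up with its position term exactly once. A minimal CPU sketch of that slice update, with hypothetical rank/world_size values:

# Usage sketch (not part of the original file): the column-parallel position-embedding add.
import torch

embed_dim, world_size, rank, seqlen = 8, 2, 1, 3      # hypothetical values
embeddings = torch.zeros(1, seqlen, embed_dim)        # word-embedding partial result, full hidden dim
partition_dim = embed_dim // world_size               # this rank's slice of the position table
position_embeddings = torch.randn(seqlen, partition_dim)

# Same slicing as ParallelGPT2Embeddings.forward: only this rank's hidden-dim slice is updated.
embeddings[..., rank * partition_dim : (rank + 1) * partition_dim] += position_embeddings
print(embeddings[0, 0])   # on rank 1 the second half is nonzero, the first half is left to rank 0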
flash_attn/modules/mha.py  deleted 100644 → 0  (view file @ 5018ac6a)
# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
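A small sketch of what these slopes look like and how they are typically turned into the additive ALiBi bias -slope * (query_pos - key_pos) that the attention kernels apply. It assumes get_alibi_slopes as defined above is in scope; the bias construction is my reading of ALiBi, not code from this file.

# Usage sketch (not part of the original file): ALiBi slopes and the resulting bias.
import torch

slopes = get_alibi_slopes(8)
print([round(s, 6) for s in slopes])    # geometric sequence 2^-1, 2^-2, ..., 2^-8 for 8 heads

seqlen = 4
i = torch.arange(seqlen)[:, None]       # query positions
j = torch.arange(seqlen)[None, :]       # key positions
bias = -torch.tensor(slopes)[:, None, None] * (i - j)   # (nheads, seqlen, seqlen)
print(bias[0])   # head 0: 0 on the diagonal, more negative for distant past keys
                 # (entries above the diagonal are masked anyway under causal attention)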
class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
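The varlen path above packs every sequence of the batch into one (total, 3, H, D) tensor and uses cu_seqlens to mark the boundaries. A small CPU sketch of how cu_seqlens and max_seqlen are typically built from per-sequence lengths (the actual kernel call additionally needs a CUDA fp16/bf16 tensor, so only the bookkeeping is shown):

# Usage sketch (not part of the original file): building cu_seqlens for the varlen path.
import torch

lengths = torch.tensor([3, 5, 2], dtype=torch.int32)   # per-sequence lengths in the batch
cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, 0, dtype=torch.int32), (1, 0))
max_seqlen = int(lengths.max())
print(cu_seqlens)    # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(max_seqlen)    # 5
# Sequence b then occupies rows cu_seqlens[b] : cu_seqlens[b + 1] of the packed (total, 3, H, D) qkv.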
class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
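The SelfAttention reference above is plain softmax attention written with einsum. As a sanity check of that math (not of this module), here is a standalone comparison against torch.nn.functional.scaled_dot_product_attention, which assumes PyTorch >= 2.0; the einsum formula is restated inline so nothing needs to be imported from the package:

# Usage sketch (not part of the original file): the einsum reference math vs. PyTorch SDPA.
import math, torch

B, S, H, D = 2, 5, 3, 8
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

scale = 1.0 / math.sqrt(D)
scores = torch.einsum("bthd,bshd->bhts", q, k * scale)
causal_mask = torch.triu(torch.full((S, S), float("-inf")), 1)
ref = torch.einsum("bhts,bshd->bthd", torch.softmax(scores + causal_mask, dim=-1), v)

# scaled_dot_product_attention expects (B, H, S, D) layout.
sdpa = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
print(torch.allclose(ref, sdpa, atol=1e-5))   # True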
class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input
def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
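A minimal sketch of the object _update_kv_cache expects, doing a prefill followed by one decode step. The SimpleNamespace is a stand-in for flash_attn's real InferenceParams (an assumption on my part); its field names are taken from the code above, and the sketch assumes _update_kv_cache as defined above is in scope.

# Usage sketch (not part of the original file): prefill + one decode step through _update_kv_cache.
import torch
from types import SimpleNamespace

params = SimpleNamespace(
    max_batch_size=2, max_seqlen=16, batch_size_offset=0, seqlen_offset=0,
    key_value_memory_dict={},
)
nheads, head_dim = 4, 8

kv_prefill = torch.randn(2, 5, 2, nheads, head_dim)               # fills cache rows 0..4
print(_update_kv_cache(kv_prefill, params, layer_idx=0).shape)    # (2, 5, 2, 4, 8)

params.seqlen_offset = 5                                          # the generation loop advances this
kv_step = torch.randn(2, 1, 2, nheads, head_dim)                  # one new token per sequence
print(_update_kv_cache(kv_step, params, layer_idx=0).shape)       # (2, 6, 2, 4, 8)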
class
MHA
(
nn
.
Module
):
"""Multi-head self-attention and cross-attention"""
def
__init__
(
self
,
embed_dim
,
num_heads
,
num_heads_kv
=
None
,
cross_attn
=
False
,
qkv_proj_bias
=
True
,
out_proj_bias
=
True
,
dropout
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
layer_idx
=
None
,
dwconv
=
False
,
rotary_emb_dim
=
0
,
rotary_emb_base
=
10000.0
,
rotary_emb_scale_base
=
None
,
rotary_emb_interleaved
=
False
,
use_alibi
=
False
,
window_size
=
(
-
1
,
-
1
),
fused_bias_fc
=
False
,
use_flash_attn
=
False
,
return_residual
=
False
,
checkpointing
=
False
,
device
=
None
,
dtype
=
None
,
)
->
None
:
"""
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
embed_dim
=
embed_dim
self
.
cross_attn
=
cross_attn
self
.
causal
=
causal
self
.
layer_idx
=
layer_idx
self
.
dwconv
=
dwconv
self
.
rotary_emb_dim
=
rotary_emb_dim
self
.
use_flash_attn
=
use_flash_attn
self
.
return_residual
=
return_residual
self
.
checkpointing
=
checkpointing
if
use_alibi
:
assert
use_flash_attn
,
"ALiBi code path requires flash_attn"
alibi_slopes
=
torch
.
tensor
(
get_alibi_slopes
(
num_heads
),
device
=
device
)
else
:
alibi_slopes
=
None
if
window_size
!=
(
-
1
,
-
1
):
assert
use_flash_attn
,
"Local (sliding window) attention code path requires flash_attn"
self
.
num_heads
=
num_heads
self
.
num_heads_kv
=
num_heads_kv
if
num_heads_kv
is
not
None
else
num_heads
assert
(
self
.
num_heads
%
self
.
num_heads_kv
==
0
),
"num_heads must be divisible by num_heads_kv"
assert
self
.
embed_dim
%
num_heads
==
0
,
"embed_dim must be divisible by num_heads"
self
.
head_dim
=
self
.
embed_dim
//
num_heads
qkv_dim
=
self
.
head_dim
*
(
self
.
num_heads
+
2
*
self
.
num_heads_kv
)
kv_dim
=
2
*
self
.
head_dim
*
self
.
num_heads_kv
if
self
.
rotary_emb_dim
>
0
:
assert
not
cross_attn
,
"MHA with rotary embedding does not support cross-attention yet"
assert
RotaryEmbedding
is
not
None
,
"rotary_emb is not installed"
self
.
rotary_emb
=
RotaryEmbedding
(
self
.
rotary_emb_dim
,
base
=
rotary_emb_base
,
scale_base
=
rotary_emb_scale_base
,
interleaved
=
rotary_emb_interleaved
,
device
=
device
,
)
if
fused_bias_fc
and
FusedDense
is
None
:
raise
ImportError
(
"fused_dense is not installed"
)
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
linear_resid_cls
=
(
LinearResidual
if
not
fused_bias_fc
else
partial
(
FusedDense
,
return_residual
=
True
)
)
wqkv_cls
=
linear_cls
if
not
self
.
return_residual
else
linear_resid_cls
inner_attn_cls
=
(
partial
(
FlashSelfAttention
,
alibi_slopes
=
alibi_slopes
,
window_size
=
window_size
)
if
use_flash_attn
else
SelfAttention
)
inner_cross_attn_cls
=
(
partial
(
FlashCrossAttention
,
alibi_slopes
=
alibi_slopes
,
window_size
=
window_size
)
if
use_flash_attn
else
CrossAttention
)
if
not
self
.
cross_attn
:
self
.
Wqkv
=
wqkv_cls
(
embed_dim
,
qkv_dim
,
bias
=
qkv_proj_bias
,
**
factory_kwargs
)
else
:
self
.
Wq
=
linear_cls
(
embed_dim
,
embed_dim
,
bias
=
qkv_proj_bias
,
**
factory_kwargs
)
self
.
Wkv
=
wqkv_cls
(
embed_dim
,
kv_dim
,
bias
=
qkv_proj_bias
,
**
factory_kwargs
)
if
self
.
dwconv
:
if
self
.
num_heads_kv
==
self
.
num_heads
:
self
.
dwconv_qkv
=
nn
.
Conv1d
(
qkv_dim
,
qkv_dim
,
kernel_size
=
3
,
padding
=
2
,
groups
=
qkv_dim
)
else
:
self
.
dwconv_q
=
nn
.
Conv1d
(
embed_dim
,
embed_dim
,
kernel_size
=
3
,
padding
=
2
,
groups
=
embed_dim
)
self
.
dwconv_kv
=
nn
.
Conv1d
(
kv_dim
,
kv_dim
,
kernel_size
=
3
,
padding
=
2
,
groups
=
kv_dim
)
self
.
inner_attn
=
inner_attn_cls
(
causal
=
causal
,
softmax_scale
=
softmax_scale
,
attention_dropout
=
dropout
,
)
self
.
inner_cross_attn
=
inner_cross_attn_cls
(
causal
=
causal
,
softmax_scale
=
softmax_scale
,
attention_dropout
=
dropout
)
self
.
out_proj
=
linear_cls
(
embed_dim
,
embed_dim
,
bias
=
out_proj_bias
,
**
factory_kwargs
)
def
allocate_inference_cache
(
self
,
batch_size
,
max_seqlen
,
dtype
=
None
):
dtype
=
self
.
out_proj
.
weight
.
dtype
if
dtype
is
None
else
dtype
device
=
self
.
out_proj
.
weight
.
device
return
torch
.
empty
(
batch_size
,
max_seqlen
,
2
,
self
.
num_heads_kv
,
self
.
head_dim
,
dtype
=
dtype
,
device
=
device
,
)
def
_update_kv_cache
(
self
,
kv
,
inference_params
):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
assert
not
self
.
dwconv
,
"Generation does not support dwconv yet"
assert
self
.
layer_idx
is
not
None
,
"Generation requires layer_idx in the constructor"
return
_update_kv_cache
(
kv
,
inference_params
,
self
.
layer_idx
)
def
_apply_rotary_update_kvcache_attention
(
self
,
q
,
kv
,
inference_params
):
"""
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
assert
inference_params
is
not
None
and
inference_params
.
seqlen_offset
>
0
assert
self
.
use_flash_attn
if
self
.
rotary_emb_dim
>
0
:
assert
self
.
rotary_emb
.
scale
is
None
,
"This code path does not support xPos"
self
.
rotary_emb
.
_update_cos_sin_cache
(
inference_params
.
max_seqlen
,
device
=
q
.
device
,
dtype
=
q
.
dtype
)
rotary_cos
,
rotary_sin
=
self
.
rotary_emb
.
_cos_cached
,
self
.
rotary_emb
.
_sin_cached
else
:
rotary_cos
,
rotary_sin
=
None
,
None
batch
=
q
.
shape
[
0
]
kv_cache
=
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
][:
batch
]
cache_seqlens
=
(
inference_params
.
lengths_per_sample
[:
batch
]
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
seqlen_offset
)
alibi_slopes
=
getattr
(
self
.
inner_cross_attn
,
"alibi_slopes"
,
None
)
context
=
flash_attn_with_kvcache
(
q
,
kv_cache
[:,
:,
0
],
kv_cache
[:,
:,
1
],
kv
[:,
:,
0
],
kv
[:,
:,
1
],
rotary_cos
=
rotary_cos
,
rotary_sin
=
rotary_sin
,
cache_seqlens
=
cache_seqlens
,
softmax_scale
=
self
.
inner_cross_attn
.
softmax_scale
,
causal
=
self
.
inner_cross_attn
.
causal
,
rotary_interleaved
=
self
.
rotary_emb
.
interleaved
if
self
.
rotary_emb_dim
>
0
else
False
,
alibi_slopes
=
alibi_slopes
,
)
return
context
def
_update_kvcache_attention
(
self
,
q
,
kv
,
inference_params
):
"""Write kv to inference_params, then do attention"""
if
(
inference_params
.
seqlen_offset
==
0
or
flash_attn_with_kvcache
is
None
or
not
self
.
use_flash_attn
):
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv
=
self
.
_update_kv_cache
(
kv
,
inference_params
)
return
self
.
inner_cross_attn
(
q
,
kv
)
else
:
batch
=
q
.
shape
[
0
]
kv_cache
=
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
][:
batch
]
cache_seqlens
=
(
inference_params
.
lengths_per_sample
[:
batch
]
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
seqlen_offset
)
alibi_slopes
=
getattr
(
self
.
inner_cross_attn
,
"alibi_slopes"
,
None
)
return
flash_attn_with_kvcache
(
q
,
kv_cache
[:,
:,
0
],
kv_cache
[:,
:,
1
],
kv
[:,
:,
0
],
kv
[:,
:,
1
],
cache_seqlens
=
cache_seqlens
,
softmax_scale
=
self
.
inner_cross_attn
.
softmax_scale
,
causal
=
self
.
inner_cross_attn
.
causal
,
alibi_slopes
=
alibi_slopes
,
)
def
forward
(
self
,
x
,
x_kv
=
None
,
key_padding_mask
=
None
,
cu_seqlens
=
None
,
max_seqlen
=
None
,
mixer_subset
=
None
,
inference_params
=
None
,
**
kwargs
,
):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
is the is the sum of the sequence lengths in the batch.
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into x. Only applicable when using
FlashAttention.
max_seqlen: int. Maximum sequence length in the batch.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
if
cu_seqlens
is
not
None
:
assert
max_seqlen
is
not
None
assert
key_padding_mask
is
None
assert
self
.
use_flash_attn
assert
not
self
.
dwconv
assert
self
.
rotary_emb_dim
==
0
if
key_padding_mask
is
not
None
:
assert
cu_seqlens
is
None
assert
max_seqlen
is
None
assert
not
self
.
use_flash_attn
if
inference_params
is
not
None
:
assert
key_padding_mask
is
None
assert
cu_seqlens
is
None
and
max_seqlen
is
None
assert
not
self
.
dwconv
kwargs
=
(
{
"cu_seqlens"
:
cu_seqlens
,
"max_seqlen"
:
max_seqlen
,
**
kwargs
}
if
self
.
use_flash_attn
else
{
"key_padding_mask"
:
key_padding_mask
,
**
kwargs
}
)
seqlen_offset
=
(
0
if
inference_params
is
None
else
(
inference_params
.
lengths_per_sample
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
seqlen_offset
)
)
rotary_max_seqlen
=
inference_params
.
max_seqlen
if
inference_params
is
not
None
else
None
batch
,
seqlen
=
x
.
shape
[:
2
]
if
not
self
.
cross_attn
and
self
.
num_heads_kv
==
self
.
num_heads
:
assert
x_kv
is
None
and
mixer_subset
is
None
if
not
self
.
return_residual
:
qkv
=
self
.
Wqkv
(
x
)
else
:
qkv
,
x
=
self
.
Wqkv
(
x
)
if
self
.
dwconv
:
qkv
=
rearrange
(
self
.
dwconv_qkv
(
rearrange
(
qkv
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
).
contiguous
()
qkv
=
rearrange
(
qkv
,
"... (three h d) -> ... three h d"
,
three
=
3
,
d
=
self
.
head_dim
)
if
(
inference_params
is
None
or
inference_params
.
seqlen_offset
==
0
or
(
self
.
rotary_emb_dim
==
0
or
self
.
rotary_emb_dim
%
16
!=
0
)
or
not
self
.
use_flash_attn
):
if
self
.
rotary_emb_dim
>
0
:
qkv
=
self
.
rotary_emb
(
qkv
,
seqlen_offset
=
seqlen_offset
,
max_seqlen
=
rotary_max_seqlen
)
if
inference_params
is
None
:
if
not
self
.
checkpointing
:
context
=
self
.
inner_attn
(
qkv
,
**
kwargs
)
else
:
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_attn
,
qkv
,
**
kwargs
)
else
:
context
=
self
.
_update_kvcache_attention
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
:],
inference_params
)
else
:
context
=
self
.
_apply_rotary_update_kvcache_attention
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
:],
inference_params
)
else
:
if
self
.
cross_attn
:
if
not
self
.
return_residual
:
q
=
self
.
Wq
(
x
if
mixer_subset
is
None
else
x
[:,
mixer_subset
])
kv
=
self
.
Wkv
(
x_kv
if
x_kv
is
not
None
else
x
)
else
:
if
x_kv
is
not
None
:
kv
,
x_kv
=
self
.
Wkv
(
x_kv
)
else
:
kv
,
x
=
self
.
Wkv
(
x
)
q
=
self
.
Wq
(
x
if
mixer_subset
is
None
else
x
[:,
mixer_subset
])
else
:
assert
self
.
num_heads_kv
!=
self
.
num_heads
if
not
self
.
return_residual
:
qkv
=
self
.
Wqkv
(
x
)
else
:
qkv
,
x
=
self
.
Wqkv
(
x
)
q
=
qkv
[...,
:
self
.
num_heads
*
self
.
head_dim
]
kv
=
qkv
[...,
self
.
num_heads
*
self
.
head_dim
:]
q
=
rearrange
(
q
,
"... (h d) -> ... h d"
,
d
=
self
.
head_dim
)
kv
=
rearrange
(
kv
,
"... (two hkv d) -> ... two hkv d"
,
two
=
2
,
d
=
self
.
head_dim
)
if
self
.
dwconv
:
q
=
rearrange
(
self
.
dwconv_q
(
rearrange
(
q
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
).
contiguous
()
kv
=
rearrange
(
self
.
dwconv_kv
(
rearrange
(
kv
,
"b s d -> b d s"
))[...,
:
-
2
],
"b d s -> b s d"
).
contiguous
()
if
(
inference_params
is
None
or
inference_params
.
seqlen_offset
==
0
or
(
self
.
rotary_emb_dim
==
0
or
self
.
rotary_emb_dim
%
16
!=
0
)
or
not
self
.
use_flash_attn
):
if
self
.
rotary_emb_dim
>
0
:
q
,
kv
=
self
.
rotary_emb
(
q
,
kv
,
seqlen_offset
=
seqlen_offset
,
max_seqlen
=
rotary_max_seqlen
)
if
inference_params
is
None
:
if
not
self
.
checkpointing
:
context
=
self
.
inner_cross_attn
(
q
,
kv
,
**
kwargs
)
else
:
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_cross_attn
,
q
,
kv
,
**
kwargs
)
else
:
context
=
self
.
_update_kvcache_attention
(
q
,
kv
,
inference_params
)
else
:
context
=
self
.
_apply_rotary_update_kvcache_attention
(
q
,
kv
,
inference_params
)
out
=
self
.
out_proj
(
rearrange
(
context
,
"... h d -> ... (h d)"
))
return
out
if
not
self
.
return_residual
else
(
out
,
x
)
class
ParallelMHA
(
nn
.
Module
):
"""Multi-head self-attention and cross-attention"""
def
__init__
(
self
,
embed_dim
,
num_heads
,
process_group
,
num_heads_kv
=
None
,
qkv_proj_bias
=
True
,
out_proj_bias
=
True
,
dropout
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
layer_idx
=
None
,
rotary_emb_dim
=
0
,
rotary_emb_base
=
10000.0
,
rotary_emb_scale_base
=
None
,
rotary_emb_interleaved
=
False
,
use_alibi
=
False
,
window_size
=
(
-
1
,
-
1
),
use_flash_attn
=
False
,
checkpointing
=
False
,
sequence_parallel
=
True
,
device
=
None
,
dtype
=
None
,
)
->
None
:
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
embed_dim
=
embed_dim
self
.
causal
=
causal
self
.
layer_idx
=
layer_idx
self
.
rotary_emb_dim
=
rotary_emb_dim
self
.
use_flash_attn
=
use_flash_attn
self
.
checkpointing
=
checkpointing
self
.
process_group
=
process_group
self
.
world_size
=
process_group
.
size
()
self
.
local_rank
=
torch
.
distributed
.
get_rank
(
process_group
)
self
.
num_heads
=
num_heads
assert
self
.
embed_dim
%
self
.
num_heads
==
0
,
"embed_dim must be divisible by num_heads"
self
.
num_heads_kv
=
num_heads_kv
if
num_heads_kv
is
not
None
else
num_heads
assert
(
self
.
num_heads
%
self
.
num_heads_kv
==
0
),
"num_heads must be divisible by num_heads_kv"
self
.
num_heads_per_rank
=
get_dim_for_local_rank
(
self
.
num_heads
,
self
.
world_size
,
self
.
local_rank
)
self
.
num_heads_kv_per_rank
=
get_dim_for_local_rank
(
self
.
num_heads_kv
,
self
.
world_size
,
self
.
local_rank
)
self
.
head_dim
=
self
.
embed_dim
//
num_heads
qkv_dim
=
self
.
head_dim
*
(
self
.
num_heads
+
2
*
self
.
num_heads_kv
)
if
use_alibi
:
assert
use_flash_attn
,
"ALiBi code path requires flash_attn"
num_heads_local
=
math
.
ceil
(
self
.
num_heads
/
self
.
world_size
)
alibi_slopes
=
torch
.
tensor
(
get_alibi_slopes
(
num_heads
)[
self
.
local_rank
*
num_heads_local
:
(
self
.
local_rank
+
1
)
*
num_heads_local
],
device
=
device
,
)
else
:
alibi_slopes
=
None
if
window_size
!=
(
-
1
,
-
1
):
assert
use_flash_attn
,
"Local (sliding window) attention code path requires flash_attn"
if
self
.
rotary_emb_dim
>
0
:
assert
RotaryEmbedding
is
not
None
,
"rotary_emb is not installed"
self
.
rotary_emb
=
RotaryEmbedding
(
self
.
rotary_emb_dim
,
base
=
rotary_emb_base
,
scale_base
=
rotary_emb_scale_base
,
interleaved
=
rotary_emb_interleaved
,
device
=
device
,
)
if
ColumnParallelLinear
is
None
or
RowParallelLinear
is
None
:
raise
ImportError
(
"fused_dense is not installed"
)
self
.
Wqkv
=
ColumnParallelLinear
(
embed_dim
,
qkv_dim
,
process_group
,
bias
=
qkv_proj_bias
,
sequence_parallel
=
sequence_parallel
,
multiple_of
=
self
.
head_dim
*
(
self
.
num_heads
//
self
.
num_heads_kv
+
2
),
**
factory_kwargs
,
)
inner_attn_cls
=
(
partial
(
FlashSelfAttention
,
alibi_slopes
=
alibi_slopes
,
window_size
=
window_size
)
if
use_flash_attn
else
SelfAttention
)
inner_cross_attn_cls
=
(
partial
(
FlashCrossAttention
,
alibi_slopes
=
alibi_slopes
,
window_size
=
window_size
)
if
use_flash_attn
else
CrossAttention
)
self
.
inner_attn
=
inner_attn_cls
(
causal
=
causal
,
softmax_scale
=
softmax_scale
,
attention_dropout
=
dropout
)
self
.
inner_cross_attn
=
inner_cross_attn_cls
(
causal
=
causal
,
softmax_scale
=
softmax_scale
,
attention_dropout
=
dropout
)
self
.
out_proj
=
RowParallelLinear
(
embed_dim
,
embed_dim
,
process_group
,
bias
=
out_proj_bias
,
sequence_parallel
=
sequence_parallel
,
multiple_of
=
self
.
head_dim
,
**
factory_kwargs
,
)
def
allocate_inference_cache
(
self
,
batch_size
,
max_seqlen
,
dtype
=
None
):
dtype
=
self
.
out_proj
.
weight
.
dtype
if
dtype
is
None
else
dtype
device
=
self
.
out_proj
.
weight
.
device
return
torch
.
empty
(
batch_size
,
max_seqlen
,
2
,
self
.
num_heads_kv_per_rank
,
self
.
head_dim
,
dtype
=
dtype
,
device
=
device
,
)
def
_update_kv_cache
(
self
,
kv
,
inference_params
):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
assert
self
.
layer_idx
is
not
None
,
"Generation requires layer_idx in the constructor"
return
_update_kv_cache
(
kv
,
inference_params
,
self
.
layer_idx
)
def
_apply_rotary_update_kvcache_attention
(
self
,
q
,
kv
,
inference_params
):
"""
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
assert
inference_params
is
not
None
and
inference_params
.
seqlen_offset
>
0
assert
self
.
use_flash_attn
if
self
.
rotary_emb_dim
>
0
:
assert
self
.
rotary_emb
.
scale
is
None
,
"This code path does not support xPos"
self
.
rotary_emb
.
_update_cos_sin_cache
(
inference_params
.
max_seqlen
,
device
=
q
.
device
,
dtype
=
q
.
dtype
)
rotary_cos
,
rotary_sin
=
self
.
rotary_emb
.
_cos_cached
,
self
.
rotary_emb
.
_sin_cached
else
:
rotary_cos
,
rotary_sin
=
None
,
None
batch
=
q
.
shape
[
0
]
kv_cache
=
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
][:
batch
]
cache_seqlens
=
(
inference_params
.
lengths_per_sample
[:
batch
]
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
seqlen_offset
)
alibi_slopes
=
getattr
(
self
.
inner_cross_attn
,
"alibi_slopes"
,
None
)
context
=
flash_attn_with_kvcache
(
q
,
kv_cache
[:,
:,
0
],
kv_cache
[:,
:,
1
],
kv
[:,
:,
0
],
kv
[:,
:,
1
],
rotary_cos
=
rotary_cos
,
rotary_sin
=
rotary_sin
,
cache_seqlens
=
cache_seqlens
,
softmax_scale
=
self
.
inner_cross_attn
.
softmax_scale
,
causal
=
self
.
inner_cross_attn
.
causal
,
rotary_interleaved
=
self
.
rotary_emb
.
interleaved
if
self
.
rotary_emb_dim
>
0
else
False
,
alibi_slopes
=
alibi_slopes
,
)
return
context
def
_update_kvcache_attention
(
self
,
q
,
kv
,
inference_params
):
"""Write kv to inference_params, then do attention"""
if
inference_params
.
seqlen_offset
==
0
or
not
self
.
use_flash_attn
:
# TODO: this only uses seqlen_offset and not lengths_per_sample.
kv
=
self
.
_update_kv_cache
(
kv
,
inference_params
)
return
self
.
inner_cross_attn
(
q
,
kv
)
else
:
batch
=
q
.
shape
[
0
]
kv_cache
=
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
][:
batch
]
cache_seqlens
=
(
inference_params
.
lengths_per_sample
[:
batch
]
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
seqlen_offset
)
alibi_slopes
=
getattr
(
self
.
inner_cross_attn
,
"alibi_slopes"
,
None
)
context
=
flash_attn_with_kvcache
(
q
,
kv_cache
[:,
:,
0
],
kv_cache
[:,
:,
1
],
kv
[:,
:,
0
],
kv
[:,
:,
1
],
cache_seqlens
=
cache_seqlens
,
softmax_scale
=
self
.
inner_cross_attn
.
softmax_scale
,
causal
=
self
.
inner_cross_attn
.
causal
,
alibi_slopes
=
alibi_slopes
,
)
return
context
def
forward
(
self
,
x
,
seqlen
=
None
,
inference_params
=
None
,
**
kwargs
):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
split x during sequence parallel, we split the batch * seqlen dimension
(in case batch is small).
"""
qkv
=
self
.
Wqkv
(
x
)
if
seqlen
is
not
None
:
qkv
=
rearrange
(
qkv
,
"(b s) ... -> b s ..."
,
s
=
seqlen
)
seqlen_offset
=
(
0
if
inference_params
is
None
else
(
inference_params
.
lengths_per_sample
if
inference_params
.
lengths_per_sample
is
not
None
else
inference_params
.
seqlen_offset
)
)
rotary_max_seqlen
=
inference_params
.
max_seqlen
if
inference_params
is
not
None
else
None
if
self
.
num_heads_kv
==
self
.
num_heads
:
qkv
=
rearrange
(
qkv
,
"b s (three h d) -> b s three h d"
,
three
=
3
,
d
=
self
.
head_dim
)
if
(
inference_params
is
None
or
inference_params
.
seqlen_offset
==
0
or
(
self
.
rotary_emb_dim
==
0
or
self
.
rotary_emb_dim
%
16
!=
0
)
or
not
self
.
use_flash_attn
):
if
self
.
rotary_emb_dim
>
0
:
qkv
=
self
.
rotary_emb
(
qkv
,
seqlen_offset
=
seqlen_offset
,
max_seqlen
=
rotary_max_seqlen
)
if
inference_params
is
None
:
if
not
self
.
checkpointing
:
context
=
self
.
inner_attn
(
qkv
,
**
kwargs
)
else
:
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_attn
,
qkv
,
**
kwargs
)
else
:
context
=
self
.
_update_kvcache_attention
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
:],
inference_params
)
else
:
context
=
self
.
_apply_rotary_update_kvcache_attention
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
:],
inference_params
)
else
:
q
=
rearrange
(
qkv
[...,
:
self
.
num_heads_per_rank
*
self
.
head_dim
],
"... (h d) -> ... h d"
,
d
=
self
.
head_dim
,
)
kv
=
rearrange
(
qkv
[...,
self
.
num_heads_per_rank
*
self
.
head_dim
:],
"... (two hkv d) -> ... two hkv d"
,
two
=
2
,
d
=
self
.
head_dim
,
)
if
(
inference_params
is
None
or
inference_params
.
seqlen_offset
==
0
or
(
self
.
rotary_emb_dim
==
0
or
self
.
rotary_emb_dim
%
16
!=
0
)
or
not
self
.
use_flash_attn
):
if
self
.
rotary_emb_dim
>
0
:
q
,
kv
=
self
.
rotary_emb
(
q
,
kv
,
seqlen_offset
=
seqlen_offset
,
max_seqlen
=
rotary_max_seqlen
)
if
inference_params
is
None
:
if
not
self
.
checkpointing
:
context
=
self
.
inner_cross_attn
(
q
,
kv
,
**
kwargs
)
else
:
context
=
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
inner_cross_attn
,
q
,
kv
,
**
kwargs
)
else
:
context
=
self
.
_update_kvcache_attention
(
q
,
kv
,
inference_params
)
else
:
context
=
self
.
_apply_rotary_update_kvcache_attention
(
q
,
kv
,
inference_params
)
context
=
rearrange
(
context
,
"b s h d -> b s (h d)"
)
if
seqlen
is
not
None
:
context
=
rearrange
(
context
,
"b s d -> (b s) d"
)
out
=
self
.
out_proj
(
context
)
return
out
flash_attn/modules/mlp.py  deleted 100644 → 0  (view file @ 5018ac6a)
# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

try:
    from flash_attn.ops.activations import swiglu
except ImportError:
    swiglu = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
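Two details of GatedMlp worth seeing concretely: the default hidden size is int(8 * in_features / 3) rounded up to a multiple of multiple_of, and the fused swiglu path computes the same value as the generic y * activation(gate) branch. A standalone sketch; nothing is imported from the module, and the swiglu formula is restated from the jiterator code string in flash_attn/ops/activations.py further down.

# Usage sketch (not part of the original file): GatedMlp's default hidden size and gating math.
import torch
import torch.nn.functional as F

in_features, multiple_of = 1024, 128
hidden = int(8 * in_features / 3)                                # 2730
hidden = (hidden + multiple_of - 1) // multiple_of * multiple_of
print(hidden)                                                    # 2816, rounded up to a multiple of 128

# fc1 produces 2 * hidden features; the two halves act as value and gate.
y = torch.randn(4, 2 * hidden)
val, gate = y.chunk(2, dim=-1)
out_generic = val * F.silu(gate)                  # the generic `y * activation(gate)` branch
out_manual = val * gate * torch.sigmoid(gate)     # what swiglu(gate, val) computes, per its code string
print(torch.allclose(out_generic, out_manual))    # True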
flash_attn/ops/__init__.py  deleted 100644 → 0  (view file @ 5018ac6a)
flash_attn/ops/activations.py  deleted 100644 → 0  (view file @ 5018ac6a)
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456


# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, input, bias)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply


# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)


class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)


class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        return swiglu_bwd(x, y, dout)


swiglu = SwiGLUFunction.apply
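The bias_gelu/gelu_fwd functions above implement the tanh approximation of GELU, which PyTorch also exposes as F.gelu(..., approximate="tanh"). A quick standalone numerical check, with the constants restated inline so nothing from the module is needed:

# Usage sketch (not part of the original file): the tanh-approx GELU above vs. PyTorch's version.
import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float32)
tanh_gelu = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))  # gelu_fwd's formula
print(torch.allclose(tanh_gelu, F.gelu(x, approximate="tanh"), atol=1e-6))  # True
print(torch.allclose(tanh_gelu, F.gelu(x), atol=1e-2))  # typically True: within ~1e-3 of exact GELU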
flash_attn/ops/fused_dense.py  deleted 100644 → 0  (view file @ 5018ac6a)
# Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from functools import partial
from typing import Optional

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)
class
FusedDenseFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
x
,
weight
,
bias
,
return_residual
=
False
,
process_group
=
None
,
sequence_parallel
=
True
):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
"""
ctx
.
compute_weight_gradient
=
weight
.
requires_grad
ctx
.
return_residual
=
return_residual
ctx
.
process_group
=
process_group
ctx
.
sequence_parallel
=
sequence_parallel
if
torch
.
is_autocast_enabled
():
x
=
x
.
to
(
dtype
=
torch
.
get_autocast_gpu_dtype
())
x
=
x
.
contiguous
()
if
process_group
is
not
None
and
sequence_parallel
:
# We want to kick off the all_gather early, before weight dtype conversion
total_x
,
handle_x
=
all_gather_raw
(
x
,
process_group
,
async_op
=
True
)
else
:
total_x
=
x
if
torch
.
is_autocast_enabled
():
weight
=
weight
.
to
(
dtype
=
torch
.
get_autocast_gpu_dtype
())
bias
=
bias
.
to
(
dtype
=
torch
.
get_autocast_gpu_dtype
())
if
bias
is
not
None
else
None
weight
=
weight
.
contiguous
()
if
process_group
is
not
None
and
sequence_parallel
:
handle_x
.
wait
()
batch_shape
,
n
=
total_x
.
shape
[:
-
1
],
total_x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if
min
(
batch_dim
,
n
,
*
weight
.
shape
)
>
65535
*
32
:
raise
RuntimeError
(
"fused_dense only supports matrix dims <= 2M"
)
output
=
F
.
linear
(
total_x
,
weight
,
bias
)
if
ctx
.
compute_weight_gradient
:
ctx
.
save_for_backward
(
x
,
weight
)
else
:
ctx
.
save_for_backward
(
weight
)
return
output
if
not
return_residual
else
(
output
,
x
)
@
staticmethod
@
custom_bwd
def
backward
(
ctx
,
grad_output
,
*
args
):
grad_output
=
grad_output
.
contiguous
()
if
ctx
.
return_residual
:
(
grad_input
,)
=
args
grad_input
=
grad_input
.
contiguous
()
process_group
=
ctx
.
process_group
sequence_parallel
=
ctx
.
sequence_parallel
if
ctx
.
compute_weight_gradient
:
x
,
weight
=
ctx
.
saved_tensors
if
process_group
is
not
None
and
sequence_parallel
:
total_x
,
handle_x
=
all_gather_raw
(
x
,
process_group
,
async_op
=
True
)
else
:
total_x
=
x
else
:
(
weight
,)
=
ctx
.
saved_tensors
total_x
=
None
batch_shape
=
grad_output
.
shape
[:
-
1
]
batch_dim
=
batch_shape
.
numel
()
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
if
ctx
.
needs_input_grad
[
0
]:
if
not
ctx
.
return_residual
:
grad_input
=
F
.
linear
(
grad_output
,
weight
.
t
())
else
:
grad_input
=
torch
.
addmm
(
grad_input
.
reshape
(
batch_dim
,
grad_input
.
shape
[
-
1
]),
grad_output
,
weight
)
grad_input
=
grad_input
.
reshape
(
*
batch_shape
,
grad_input
.
shape
[
-
1
])
if
process_group
is
not
None
:
reduce_fn
=
reduce_scatter_raw
if
sequence_parallel
else
all_reduce_raw
grad_input
,
handle_grad_input
=
reduce_fn
(
grad_input
,
process_group
,
async_op
=
True
)
else
:
grad_input
=
None
if
ctx
.
needs_input_grad
[
1
]:
assert
ctx
.
compute_weight_gradient
if
process_group
is
not
None
and
sequence_parallel
:
handle_x
.
wait
()
grad_weight
,
grad_bias
=
fused_dense_cuda
.
linear_bias_wgrad
(
total_x
.
reshape
(
batch_dim
,
total_x
.
shape
[
-
1
]),
grad_output
,
ctx
.
needs_input_grad
[
2
]
)
else
:
grad_weight
=
None
grad_bias
=
grad_output
if
ctx
.
needs_input_grad
[
2
]
else
None
if
process_group
is
not
None
and
ctx
.
needs_input_grad
[
0
]:
handle_grad_input
.
wait
()
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
,
None
def
fused_dense_func
(
x
:
Tensor
,
weight
:
Tensor
,
bias
:
Optional
[
Tensor
]
=
None
,
return_residual
:
bool
=
False
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
sequence_parallel
:
bool
=
True
,
):
dtype_eligible
=
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
or
(
x
.
dtype
==
torch
.
float32
and
torch
.
is_autocast_enabled
()
)
if
x
.
is_cuda
and
weight
.
is_cuda
and
(
bias
is
None
or
bias
.
is_cuda
)
and
dtype_eligible
:
return
FusedDenseFunc
.
apply
(
x
,
weight
,
bias
,
return_residual
,
process_group
,
sequence_parallel
)
else
:
assert
process_group
is
None
out
=
F
.
linear
(
x
,
weight
,
bias
)
return
out
if
not
return_residual
else
(
out
,
x
)
class
FusedDense
(
nn
.
Linear
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
bias
:
bool
=
True
,
return_residual
:
bool
=
False
,
device
=
None
,
dtype
=
None
,
)
->
None
:
super
().
__init__
(
in_features
,
out_features
,
bias
=
bias
,
device
=
device
,
dtype
=
dtype
)
self
.
return_residual
=
return_residual
def
forward
(
self
,
x
,
process_group
=
None
):
"""
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
"""
return
fused_dense_func
(
x
,
self
.
weight
,
self
.
bias
,
return_residual
=
self
.
return_residual
,
process_group
=
process_group
,
)
class
ColumnParallelLinear
(
nn
.
Linear
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
process_group
:
ProcessGroup
,
bias
:
bool
=
True
,
sequence_parallel
=
True
,
multiple_of
=
1
,
device
=
None
,
dtype
=
None
,
)
->
None
:
world_size
=
torch
.
distributed
.
get_world_size
(
process_group
)
if
out_features
%
multiple_of
:
raise
ValueError
(
f
"out_features (
{
out_features
}
) must be a multiple of
{
multiple_of
}
"
)
multiple
=
out_features
//
multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div
=
multiple
//
world_size
mod
=
multiple
%
world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple
=
div
+
int
(
torch
.
distributed
.
get_rank
(
process_group
)
<
mod
)
super
().
__init__
(
in_features
,
local_multiple
*
multiple_of
,
bias
=
bias
,
device
=
device
,
dtype
=
dtype
)
self
.
process_group
=
process_group
self
.
sequence_parallel
=
sequence_parallel
def
forward
(
self
,
x
):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return
fused_dense_func
(
x
,
self
.
weight
,
self
.
bias
,
process_group
=
self
.
process_group
,
sequence_parallel
=
self
.
sequence_parallel
,
)
class
RowParallelLinear
(
nn
.
Linear
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
process_group
:
ProcessGroup
,
bias
:
bool
=
True
,
sequence_parallel
=
True
,
multiple_of
=
1
,
device
=
None
,
dtype
=
None
,
)
->
None
:
world_size
=
torch
.
distributed
.
get_world_size
(
process_group
)
rank
=
torch
.
distributed
.
get_rank
(
process_group
)
if
in_features
%
multiple_of
:
raise
ValueError
(
f
"in_features (
{
in_features
}
) must be a multiple of
{
multiple_of
}
"
)
multiple
=
in_features
//
multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div
=
multiple
//
world_size
mod
=
multiple
%
world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple
=
div
+
int
(
torch
.
distributed
.
get_rank
(
process_group
)
<
mod
)
# Only rank 0 will have bias
super
().
__init__
(
local_multiple
*
multiple_of
,
out_features
,
bias
=
bias
and
rank
==
0
,
device
=
device
,
dtype
=
dtype
,
)
self
.
process_group
=
process_group
self
.
sequence_parallel
=
sequence_parallel
def
forward
(
self
,
x
):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out
=
fused_dense_func
(
x
,
self
.
weight
,
self
.
bias
)
reduce_fn
=
reduce_scatter
if
self
.
sequence_parallel
else
all_reduce
return
reduce_fn
(
out
,
self
.
process_group
)
class FusedMLPFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        weight1,
        bias1,
        weight2,
        bias2,
        activation="gelu_approx",
        save_pre_act=True,
        return_residual=False,
        checkpoint_lvl=0,
        heuristic=0,
        process_group=None,
        sequence_parallel=True,
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.

        checkpoint_lvl:
            0: no recomputation in the bwd
            1: recompute gelu_out / relu_out in the bwd
            2: recompute pre_act and gelu_out / relu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        if activation == "sqrelu":
            assert heuristic == -1
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.activation = activation
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        if heuristic == -1:
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (
                partial(F.gelu, approximate="tanh")
                if activation == "gelu_approx"
                else (sqrelu_fwd if activation == "sqrelu" else F.relu)
            )
            with torch.jit.fuser("fuser2"):
                output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(pre_act, bias1)
        else:
            is_gelu = activation == "gelu_approx"
            output1, *rest = fused_dense_cuda.linear_act_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
            )
            if save_pre_act:
                pre_act = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
            # For RELU the pre_act is very small (just a bit-mask) so we just save it
            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, pre_act)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else F.relu)
        )
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                (pre_act,) = rest
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            (bias1,) = rest
            if process_group is not None and sequence_parallel:
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    weight1,
                    bias1,
                    activation == "gelu_approx",
                    True,
                    ctx.heuristic,
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (
                gelu_bwd
                if activation == "gelu_approx"
                else (sqrelu_bwd if activation == "sqrelu" else relu_bwd)
            )
            with torch.jit.fuser("fuser2"):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    grad_pre_act,
                    ctx.needs_input_grad[2],
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1 = F.linear(
                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
                )
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (
            grad_input,
            grad_weight1,
            grad_bias1,
            grad_weight2,
            grad_bias2,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def fused_mlp_func(
    x: Tensor,
    weight1: Tensor,
    weight2: Tensor,
    bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    activation: str = "gelu_approx",
    save_pre_act: bool = True,
    return_residual: bool = False,
    checkpoint_lvl: int = 0,
    heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    assert activation in ["gelu_approx", "relu", "sqrelu"]
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0)
    if (
        x.is_cuda
        and weight1.is_cuda
        and weight2.is_cuda
        and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda)
        and dtype_eligible
        and dim_eligible
    ):
        return FusedMLPFunc.apply(
            x,
            weight1,
            bias1,
            weight2,
            bias2,
            activation,
            save_pre_act,
            return_residual,
            checkpoint_lvl,
            heuristic,
            process_group,
            sequence_parallel,
        )
    else:
        assert process_group is None
        pre_act = F.linear(x, weight1, bias1)
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else partial(F.relu, inplace=True)
        )
        output1 = activation_fn(pre_act)
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
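

# --- Illustrative sketch (not part of the original file) ---
# The fast path of fused_mlp_func and the eager fallback above compute the same
# Linear -> approximate-GELU -> Linear. A quick numerical check under assumed shapes;
# heuristic=-1 is used so the sketch avoids the fused cuBLASLt kernel and only needs
# the torch.jit fuser path.
def _example_check_fused_mlp_func():
    x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.float16)
    w1 = torch.randn(4096, 1024, device="cuda", dtype=torch.float16) / 32
    w2 = torch.randn(1024, 4096, device="cuda", dtype=torch.float16) / 64
    out_fused = fused_mlp_func(x, w1, w2, activation="gelu_approx", heuristic=-1)
    out_ref = F.linear(F.gelu(F.linear(x, w1), approximate="tanh"), w2)
    # fp16 accumulation differs slightly between code paths, so compare with a tolerance
    print((out_fused - out_ref).abs().max())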
class FusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        activation="gelu_approx",
        return_residual=False,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt
                implementation is slower than the unfused version.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic if activation != "sqrelu" else -1
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == "auto":
            if self.activation == "gelu_approx":
                if torch.cuda.get_device_capability("cuda") == (9, 0):
                    heuristic = -1
                else:
                    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=process_group,
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
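

# --- Illustrative sketch (not part of the original file) ---
# Typical single-GPU use of FusedMLP; checkpoint_lvl trades backward-pass recompute for
# activation memory (0 = save everything, 2 = recompute pre_act and gelu_out). Shapes and
# dtype are assumptions for the sketch.
def _example_fused_mlp():
    mlp = FusedMLP(1024, 4096, checkpoint_lvl=1, device="cuda", dtype=torch.float16)
    x = torch.randn(8, 2048, 1024, device="cuda", dtype=torch.float16)
    return mlp(x)  # same shape as x: (8, 2048, 1024)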
class ParallelFusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation="gelu_approx",
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        sequence_parallel=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        assert process_group is not None
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic if activation != "sqrelu" else -1
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )

    def forward(self, x):
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == "auto":
            if self.activation == "gelu_approx":
                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
flash_attn/ops/layer_norm.py deleted 100644 → 0
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
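

# --- Illustrative sketch (not part of the original file) ---
# Typical pre-norm residual-branch use of DropoutAddLayerNorm: it fuses dropout(x0) + residual
# followed by LayerNorm, and with prenorm=True also returns the updated residual stream.
# Shapes and dtype are assumptions; this needs the dropout_layer_norm CUDA extension and a
# supported hidden size.
def _example_dropout_add_layer_norm():
    norm = DropoutAddLayerNorm(1024, prenorm=True, p=0.1, device="cuda", dtype=torch.float16)
    residual = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
    x0 = torch.randn_like(residual)  # e.g. the attention or MLP branch output
    normed, new_residual = norm(x0, residual)  # normed feeds the next sublayer
    return normed, new_residual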
flash_attn/ops/rms_norm.py deleted 100644 → 0
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
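

# --- Illustrative sketch (not part of the original file) ---
# RMSNorm is the bias-free, mean-free variant of LayerNorm: y = x * rsqrt(mean(x^2) + eps) * weight.
# A pure-PyTorch reference for what the fused kernel computes (fp32 accumulation and the
# tolerance you would compare with are assumptions of this sketch).
def _example_rms_norm_reference(x, weight, eps=1e-5):
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd).to(x.dtype) * weight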
flash_attn/ops/triton/__init__.py deleted 100644 → 0
flash_attn/ops/triton/cross_entropy.py deleted 100644 → 0
# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional, Union

import torch

import triton
import triton.language as tl

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = (
        tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
            tl.float32
        )
        * logit_scale
    )
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = (
        tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
            tl.float32
        )
        * logit_scale
    )
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)


class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = (
            1 if process_group is None else torch.distributed.get_world_size(process_group)
        )
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                z_losses,
                logits,
                labels,
                smoothing,
                logit_scale,
                lse_square_scale,
                ignore_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if n_splits > 1:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are only for logging.

        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignore_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        return dlogits, None, None, None, None, None, None, None, None
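

# --- Illustrative sketch (not part of the original file) ---
# The split-LSE reduction used in CrossEntropyLoss.forward relies on log-sum-exp composing
# over chunks: logsumexp(x) == logsumexp([logsumexp(chunk) for chunk in chunks]). A tiny
# CPU-side check (shapes are assumptions of this sketch):
def _example_split_lse():
    logits = torch.randn(4, 8192)
    full = torch.logsumexp(logits, dim=-1)
    chunked = torch.logsumexp(
        torch.stack([torch.logsumexp(c, dim=-1) for c in logits.chunk(4, dim=-1)]), dim=0
    )
    print(torch.allclose(full, chunked, atol=1e-5))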
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        logit_scale: float. Multiply logits by this scale before calculating the loss.
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits,
        labels,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
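

# --- Illustrative sketch (not part of the original file) ---
# Single-GPU use of cross_entropy_loss, cross-checked against PyTorch's unreduced
# cross-entropy; z_losses is only non-zero when lse_square_scale > 0. The vocabulary size,
# batch size, and tolerance are assumptions of this sketch.
def _example_cross_entropy_loss():
    logits = torch.randn(32, 50257, device="cuda", dtype=torch.float32, requires_grad=True)
    labels = torch.randint(0, 50257, (32,), device="cuda")
    losses, z_losses = cross_entropy_loss(logits, labels)
    ref = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
    print(torch.allclose(losses, ref, atol=1e-4), z_losses.abs().max())
    return losses.mean()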