gaoqiong / flash-attention · Commits

Commit ae856f3a, authored Mar 28, 2024 by Woosuk Kwon
Remove unnecessary files
Parent: 6ac8e63a
Changes: 45
Showing 20 changed files with 0 additions and 7084 deletions (+0, -7084)
vllm_flash_attn/models/llama.py               +0  -422
vllm_flash_attn/models/opt.py                 +0  -116
vllm_flash_attn/models/vit.py                 +0  -373
vllm_flash_attn/modules/__init__.py           +0  -0
vllm_flash_attn/modules/block.py              +0  -397
vllm_flash_attn/modules/embedding.py          +0  -216
vllm_flash_attn/modules/mha.py                +0  -1020
vllm_flash_attn/modules/mlp.py                +0  -191
vllm_flash_attn/ops/__init__.py               +0  -0
vllm_flash_attn/ops/activations.py            +0  -135
vllm_flash_attn/ops/fused_dense.py            +0  -688
vllm_flash_attn/ops/layer_norm.py             +0  -800
vllm_flash_attn/ops/rms_norm.py               +0  -174
vllm_flash_attn/ops/triton/__init__.py        +0  -1
vllm_flash_attn/ops/triton/cross_entropy.py   +0  -320
vllm_flash_attn/ops/triton/k_activations.py   +0  -162
vllm_flash_attn/ops/triton/layer_norm.py      +0  -1086
vllm_flash_attn/ops/triton/linear.py          +0  -594
vllm_flash_attn/ops/triton/mlp.py             +0  -149
vllm_flash_attn/ops/triton/rotary.py          +0  -240
vllm_flash_attn/models/llama.py deleted 100644 → 0

# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig
from einops import rearrange


def remap_state_dict_meta_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Meta format to standard GPT format.

    This function modifies state_dict in place.
    """

    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    state_dict.pop("transformer.rope.freqs", None)

    return state_dict


def remap_state_dict_hf_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Hugging Face format to standard GPT format.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^model.layers.(\d+).mlp.down_proj.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^model.layers.(\d+).input_layernorm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return rearrange(
            w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def inv_remap_state_dict_hf_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in standard GPT format to Hugging Face format.

    This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
    multiplier pad in the embedding and lm_head. That is, if the original embedding
    isn't a multiple of pad_vocab_size_multiple, then
    inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("model.embed_tokens.weight")
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["model.embed_tokens.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        w3, w1 = torch.chunk(
            state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
        )
        state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
        state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.",
            r"model.layers.\1.mlp.down_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm1.",
            r"model.layers.\1.input_layernorm.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).norm2.",
            r"model.layers.\1.post_attention_layernorm.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def permute(w):
        return rearrange(
            w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
        )

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    q_dim = n_head * head_dim
    k_dim = v_dim = n_head_kv * head_dim

    # Attention
    for l in range(config.n_layer):
        Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
        Wq = Wqkv[:q_dim]
        Wk = Wqkv[q_dim : q_dim + k_dim]
        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
        state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
        state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.",
            r"model.layers.\1.self_attn.o_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=params.get("n_kv_heads", None),
    )
    multiple_of = params.get("multiple_of", 1)
    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)

    # Compute the hidden dimension of the MLP
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
    intermediate_size = 4 * config.hidden_size
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
    intermediate_size = int(2 * intermediate_size / 3)
    # custom dim factor multiplier
    if ffn_dim_multiplier is not None:
        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)

    config.intermediate_size = intermediate_size
    if "rope_theta" in params:
        config.rotary_emb_base = params["rope_theta"]
    config.vocab_size = 32000
    # some CodeLLaMa have vocab_size 32000, some 32016
    # Sadly it's not specified in the `params.json` file :(
    tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
    if tokenizer.is_file():
        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
    return config


def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")


def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    if checkpoint_format == "meta":
        return config_from_meta_checkpoint(checkpoint_path, model_name)
    else:
        return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> List[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Idk if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
        rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
        n_head_kv=llama_config.num_key_value_heads,
    )
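For reference, a minimal sketch of how the helpers above compose when loading a Meta-format checkpoint into the GPT-style naming used by this package. The checkpoint path and model name are placeholders, and a single-shard checkpoint (e.g. the 7B model, one consolidated.00.pth) is assumed; multi-shard checkpoints split individual tensors across files and need an extra combine step not shown here.

# Usage sketch (not part of the deleted file); assumes ./checkpoints/llama/7B contains
# params.json, tokenizer.model and a single consolidated.00.pth.
checkpoint_path, model_name = "./checkpoints/llama", "7B"

llama_config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
gpt2_config = llama_config_to_gpt2_config(llama_config)

shards = state_dicts_from_checkpoint(checkpoint_path, model_name)
state_dict = remap_state_dict_meta_llama(shards[0], gpt2_config)

# Keys now follow the GPT naming used elsewhere in this package, e.g.
# transformer.layers.0.mixer.Wqkv.weight and transformer.layers.0.mlp.fc1.weight.
print(sorted(state_dict)[:5])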
vllm_flash_attn/models/opt.py deleted 100644 → 0

# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig


def remap_state_dict_hf_opt(state_dict, config):
    def key_mapping_model(key):
        key = re.sub(r"^model.decoder.", "transformer.", key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r"^decoder.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(
            r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key
        )
        # The OPT-350m model has project_in and project_out
        key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
        key = re.sub(r"^transformer.project_out.", "project_out.", key)
        key = re.sub(
            r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
    state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn_layer_norm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
        bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
        bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
        bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).self_attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    word_embed_proj_dim = (
        None
        if opt_config.word_embed_proj_dim == opt_config.hidden_size
        else opt_config.word_embed_proj_dim
    )
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
        n_embd=opt_config.hidden_size,
        n_layer=opt_config.num_hidden_layers,
        n_head=opt_config.num_attention_heads,
        n_inner=opt_config.ffn_dim,
        activation_function=opt_config.activation_function,
        resid_pdrop=opt_config.dropout,
        # HF's implementation of OPT doesn't seem to have embedding dropout
        embd_pdrop=opt_config.dropout,
        attn_pdrop=opt_config.attention_dropout,
        initializer_range=opt_config.init_std,
        bos_token_id=opt_config.bos_token_id,
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
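A minimal usage sketch of the two helpers above. The Hugging Face model id is just an example; any OPT checkpoint with layerdrop == 0.0 should convert the same way, and downloading the weights is assumed to be possible.

# Usage sketch (not part of the deleted file).
from transformers import OPTConfig, OPTForCausalLM

opt_config = OPTConfig.from_pretrained("facebook/opt-125m")
gpt2_config = opt_config_to_gpt2_config(opt_config)

hf_state_dict = OPTForCausalLM.from_pretrained("facebook/opt-125m").state_dict()
state_dict = remap_state_dict_hf_opt(hf_state_dict, gpt2_config)

# Keys are now in the transformer.* / lm_head.* naming, with a fused Wqkv per layer.
print(any(k.endswith("mixer.Wqkv.weight") for k in state_dict))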
vllm_flash_attn/models/vit.py deleted 100644 → 0

# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from collections import OrderedDict
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None


def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False):
    mixer_cls = partial(
        MHA,
        num_heads=num_heads,
        cross_attn=cross_attn,
        qkv_proj_bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_mlp:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
    return mlp_cls


def create_block(
    embed_dim,
    num_heads,
    mlp_ratio,
    qkv_bias,
    drop_rate,
    attn_drop_rate,
    drop_path1,
    drop_path2,
    norm_layer,
    act_layer,
    use_flash_attn,
    fused_bias_fc,
    fused_mlp,
    fused_dropout_add_ln,
    layer_idx=None,
    n_layer=None,
    last_layer_subset=False,
):
    mixer_cls = create_mixer_cls(
        num_heads,
        qkv_bias,
        attn_drop_rate,
        use_flash_attn,
        fused_bias_fc,
        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
    )
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    block = Block(
        embed_dim,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
    return block


class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        use_flash_attn=False,
        fused_bias_fc=False,
        fused_mlp=False,
        fused_dropout_add_ln=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == "token", "Only support pooling with CLS token"
        assert class_token
        assert init_values is None, "LayerScale is not supported yet"
        assert weight_init == ""
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = (
            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
        )
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias,
                    drop_rate,
                    attn_drop_rate,
                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
                    drop_path2=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_flash_attn=use_flash_attn,
                    fused_bias_fc=fused_bias_fc,
                    fused_mlp=fused_mlp,
                    fused_dropout_add_ln=fused_dropout_add_ln,
                    layer_idx=i,
                    n_layer=depth,
                    last_layer_subset=(global_pool == "token"),
                )
                for i in range(depth)
            ]
        )

        self.dropout = nn.Dropout(p=drop_rate)
        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
        self.norm = norm_layer(embed_dim)

        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=""):
        assert mode == ""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return x

    def forward_features(self, x, all_tokens=True):
        """
        If all_tokens==False and self.global_pool == 'token', we only return the features for the
        cls token.
        """
        x = self.patch_embed(x)
        hidden_states = self._pos_embed(x)
        residual = None
        if self.global_pool != "token" or all_tokens:
            # if True:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
            for block in self.blocks[:-1]:
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states, residual = self.blocks[-1](
                hidden_states, residual, mixer_subset=slice(0, 1)
            )
        if not self.fused_dropout_add_ln:
            residual = self.drop_path(self.dropout(hidden_states)) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
        else:
            if self.drop_path.p == 0 or not self.training:
                rowscale = None
            else:
                rowscale = self.drop_path(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                eps=self.norm.eps,
                dropout_p=self.dropout.p if self.training else 0.0,
                rowscale=rowscale,
                prenorm=False,
            )
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x, all_tokens=False)
        x = self.forward_head(x)
        return x

    def load_state_dict(self, state_dict, strict=True):
        patch_embed_weight = state_dict["patch_embed.proj.weight"]
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict["patch_embed.proj.weight"] = rearrange(
                patch_embed_weight, "o c h w -> o (c h w)"
            )

        def key_mapping_attn(key):
            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
            return key

        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (
            self.blocks[-1].mixer.cross_attn
            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
        ):
            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
        return super().load_state_dict(state_dict, strict=strict)


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model
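For reference, a minimal sketch of instantiating the ViT wrapper above with FlashAttention enabled. It assumes a CUDA device, a working flash-attn install, and fp16 inputs (required by the FlashAttention kernels); the fused Triton/bias paths are left off.

# Usage sketch (not part of the deleted file): ViT-B/16 forward pass on dummy data.
import torch

model = vit_base_patch16_224(use_flash_attn=True, fused_bias_fc=False, fused_mlp=False)
model = model.cuda().half().eval()

images = torch.randn(2, 3, 224, 224, device="cuda", dtype=torch.float16)
with torch.no_grad():
    logits = model(images)  # (2, 1000): CLS-token pooling followed by the linear head
print(logits.shape)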
vllm_flash_attn/modules/__init__.py deleted 100644 → 0 (empty file)
vllm_flash_attn/modules/block.py deleted 100644 → 0

# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
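A minimal sketch of how the prenorm Block threads (hidden_states, residual) through a stack, using the default MHA/Mlp sub-layers. Sizes are arbitrary, and it assumes the non-flash attention fallback in flash_attn.modules.mha runs on CPU (no Triton or CUDA paths enabled).

# Usage sketch (not part of the deleted file): two stacked prenorm Blocks.
import torch

dim = 256
blocks = torch.nn.ModuleList([Block(dim, prenorm=True) for _ in range(2)])

hidden_states = torch.randn(1, 16, dim)  # (batch, seqlen, dim)
residual = None                          # no residual before the very first block
for block in blocks:
    hidden_states, residual = block(hidden_states, residual)
# hidden_states is the MLP output of the last block; residual is the running Add branch.
print(hidden_states.shape, residual.shape)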
vllm_flash_attn/modules/embedding.py deleted 100644 → 0

# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from flash_attn.utils.distributed import all_reduce, reduce_scatter


class GPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        padding_idx=None,
        word_embed_proj_dim=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
            then project up to embed_dim
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(
                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(
                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = nn.Linear(
                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
            )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings


class BertEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings


class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
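A minimal sketch of the single-process embedding module above. The parallel variants need an initialized torch.distributed process group and are not shown; the sizes are GPT-2-like but otherwise arbitrary.

# Usage sketch (not part of the deleted file): word + learned absolute position embeddings.
import torch

emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
input_ids = torch.randint(0, 50257, (2, 16))  # (batch, seqlen)
hidden_states = emb(input_ids)                # (2, 16, 768)
print(hidden_states.shape)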
vllm_flash_attn/modules/mha.py
deleted
100644 → 0
View file @
6ac8e63a
# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
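To make the recursion above concrete, here is a small check of the returned slopes (pure Python, assuming get_alibi_slopes as defined above is in scope): a power-of-two head count yields a plain geometric sequence, while other counts reuse those slopes and fill the remainder from every other slope of the next power of two.

import math

# Power-of-two head counts give a geometric sequence 1/2, 1/4, ..., 1/256:
assert get_alibi_slopes(8) == [2 ** (-(i + 1)) for i in range(8)]

# 12 heads: first 8 slopes are those of 8 heads, the remaining 4 are
# interpolated from every other slope of the 16-head sequence.
slopes_12 = get_alibi_slopes(12)
assert slopes_12[:8] == get_alibi_slopes(8)
assert math.isclose(slopes_12[8], 2 ** -0.5)   # first interpolated slope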
class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
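Note that the causal branch in CrossAttention aligns the mask to the bottom-right corner when seqlen_q differs from seqlen_k, so the last query attends to every key. A small sketch of that indexing (sizes made up, no attention computed):

import torch

seqlen_q, seqlen_k = 2, 5
row_idx = torch.arange(seqlen_q).unsqueeze(-1)         # "s -> s 1"
col_idx = torch.arange(seqlen_k)
causal_mask = col_idx > row_idx + seqlen_k - seqlen_q  # True = masked out
print(causal_mask.int())
# [[0, 0, 0, 0, 1],
#  [0, 0, 0, 0, 0]]  i.e. the first of the two queries sees keys 0..3, the last one sees all 5.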
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
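_update_kv_cache only touches a handful of attributes of inference_params. As a rough, self-contained sketch of a prefill followed by one decode step, using a hypothetical stand-in container (the real InferenceParams object in flash_attn carries more fields than this):

import torch
from dataclasses import dataclass, field

@dataclass
class _ToyInferenceParams:  # hypothetical stand-in, not the real flash_attn class
    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)

params = _ToyInferenceParams(max_seqlen=128, max_batch_size=2)
kv_prompt = torch.randn(2, 16, 2, 8, 64)            # prefill: 16 tokens
_update_kv_cache(kv_prompt, params, layer_idx=0)    # writes rows [0:16) of the cache
params.seqlen_offset = 16
kv_step = torch.randn(2, 1, 2, 8, 64)               # decode: 1 new token
cached = _update_kv_cache(kv_step, params, layer_idx=0)
print(cached.shape)                                 # torch.Size([2, 17, 2, 8, 64])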
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
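As a usage sketch for the non-parallel MHA defined above (hypothetical sizes; the use_flash_attn=True path requires a CUDA device with the flash-attn and rotary kernels installed):

import torch

# Hypothetical configuration: 8 query heads, 2 KV heads (GQA), rotary over the full head dim.
mha = MHA(
    embed_dim=512,
    num_heads=8,
    num_heads_kv=2,
    causal=True,
    rotary_emb_dim=64,
    use_flash_attn=True,
    device="cuda",
    dtype=torch.bfloat16,
)
x = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)
out = mha(x)   # (4, 128, 512): Wqkv -> rotary -> FlashCrossAttention (GQA path) -> out_proj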
vllm_flash_attn/modules/mlp.py    deleted  100644 → 0    View file @ 6ac8e63a
# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

try:
    from flash_attn.ops.activations import swiglu
except ImportError:
    swiglu = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
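The hidden-width rule used by GatedMlp and ParallelGatedMlp above (roughly 8/3 of in_features, rounded up to a multiple of multiple_of) can be checked directly; the figures below are only an illustration:

in_features, multiple_of = 4096, 128
hidden_features = int(8 * in_features / 3)                                    # 10922
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
print(hidden_features)                                                        # 11008 = 86 * 128
# fc1 then maps 4096 -> 2 * 11008 (value and gate halves); fc2 maps 11008 -> 4096.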
vllm_flash_attn/ops/__init__.py    deleted  100644 → 0    View file @ 6ac8e63a
vllm_flash_attn/ops/activations.py    deleted  100644 → 0    View file @ 6ac8e63a
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456


# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, input, bias)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply


# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)


class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)


class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        return swiglu_bwd(x, y, dout)


swiglu = SwiGLUFunction.apply
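Since swiglu_fwd above is a jiterator kernel, it only runs on CUDA tensors. A quick sanity check of the forward result against the eager SwiGLU formula silu(x) * y, assuming a GPU is available:

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    y = torch.randn_like(x)
    out = swiglu(x, y)                       # jiterator forward defined above: x * y * sigmoid(x)
    ref = F.silu(x.float()) * y.float()      # eager reference in fp32
    assert torch.allclose(out.float(), ref, atol=1e-2, rtol=1e-2)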
vllm_flash_attn/ops/fused_dense.py    deleted  100644 → 0    View file @ 6ac8e63a
# Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from functools import partial
from typing import Optional

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class FusedDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None


def fused_dense_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    return_residual: bool = False,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FusedDenseFunc.apply(
            x, weight, bias, return_residual, process_group, sequence_parallel
        )
    else:
        assert process_group is None
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)


class FusedDense(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        return_residual: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            return_residual=self.return_residual,
            process_group=process_group,
        )


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = fused_dense_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
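The div/mod bookkeeping shared by ColumnParallelLinear and RowParallelLinear above decides how many multiple_of-sized blocks each rank owns when the split is uneven; a purely arithmetic illustration with made-up numbers:

out_features, multiple_of, world_size = 1280, 128, 3
multiple = out_features // multiple_of          # 10 blocks of 128 columns
div, mod = divmod(multiple, world_size)         # div=3, mod=1: first `mod` ranks get one extra block
per_rank = [(div + (rank < mod)) * multiple_of for rank in range(world_size)]
print(per_rank)                                 # [512, 384, 384]; sums back to 1280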
class
FusedMLPFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
x
,
weight1
,
bias1
,
weight2
,
bias2
,
activation
=
"gelu_approx"
,
save_pre_act
=
True
,
return_residual
=
False
,
checkpoint_lvl
=
0
,
heuristic
=
0
,
process_group
=
None
,
sequence_parallel
=
True
,
):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather of x before doing the matmul.
If sequence_parallel=False, then the input is already gathered.
checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out / relu_out in the bwd
2: recompute pre_act and gelu_out / relu_out in the bwd
"""
assert
-
1
<=
heuristic
<=
4
assert
activation
in
[
"gelu_approx"
,
"relu"
,
"sqrelu"
]
if
activation
==
"sqrelu"
:
assert
heuristic
==
-
1
if
not
save_pre_act
:
checkpoint_lvl
=
2
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
ctx
.
return_residual
=
return_residual
ctx
.
process_group
=
process_group
ctx
.
sequence_parallel
=
sequence_parallel
ctx
.
checkpoint_lvl
=
checkpoint_lvl
ctx
.
activation
=
activation
ctx
.
heuristic
=
heuristic
if
torch
.
is_autocast_enabled
():
x
=
x
.
to
(
dtype
=
torch
.
get_autocast_gpu_dtype
())
x
=
x
.
contiguous
()
if
process_group
is
not
None
and
sequence_parallel
:
# We want to kick off the all_gather early, before weight dtype conversion
total_x
,
handle_x
=
all_gather_raw
(
x
,
process_group
,
async_op
=
True
)
else
:
total_x
=
x
if
torch
.
is_autocast_enabled
():
dtype
=
torch
.
get_autocast_gpu_dtype
()
weight1
,
weight2
=
[
a
.
to
(
dtype
=
dtype
)
for
a
in
[
weight1
,
weight2
]]
bias1
=
bias1
.
to
(
dtype
=
dtype
)
if
bias1
is
not
None
else
None
bias2
=
bias2
.
to
(
dtype
=
dtype
)
if
bias2
is
not
None
else
None
weight1
=
weight1
.
contiguous
()
bias1
=
bias1
.
contiguous
()
if
bias1
is
not
None
else
None
weight2
=
weight2
.
contiguous
()
bias2
=
bias2
.
contiguous
()
if
bias2
is
not
None
else
None
if
process_group
is
not
None
and
sequence_parallel
:
handle_x
.
wait
()
batch_shape
,
n
=
total_x
.
shape
[:
-
1
],
total_x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if
min
(
batch_dim
,
n
,
*
weight1
.
shape
,
*
weight2
.
shape
)
>
65535
*
32
:
raise
RuntimeError
(
"fused_dense only supports matrix dims <= 2M"
)
if
heuristic
==
-
1
:
pre_act
=
F
.
linear
(
total_x
,
weight1
,
bias1
)
activation_fn
=
(
partial
(
F
.
gelu
,
approximate
=
"tanh"
)
if
activation
==
"gelu_approx"
else
(
sqrelu_fwd
if
activation
==
"sqrelu"
else
F
.
relu
)
)
with
torch
.
jit
.
fuser
(
"fuser2"
):
output1
=
activation_fn
(
pre_act
)
# This is before adding bias1
# pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
# with torch.jit.fuser('fuser2'):
# output1 = bias_gelu(pre_act, bias1)
else
:
is_gelu
=
activation
==
"gelu_approx"
output1
,
*
rest
=
fused_dense_cuda
.
linear_act_forward
(
total_x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
is_gelu
,
save_pre_act
,
heuristic
)
if
save_pre_act
:
pre_act
=
rest
[
0
]
output2
=
F
.
linear
(
output1
,
weight2
,
bias2
)
if
checkpoint_lvl
==
0
or
(
checkpoint_lvl
==
1
and
activation
==
"relu"
):
# For RELU the pre_act is very small (just a bit-mask) so we just save it
ctx
.
save_for_backward
(
x
,
weight1
,
weight2
,
pre_act
,
output1
)
elif
checkpoint_lvl
==
1
:
ctx
.
save_for_backward
(
x
,
weight1
,
weight2
,
pre_act
)
elif
checkpoint_lvl
==
2
:
ctx
.
save_for_backward
(
x
,
weight1
,
weight2
,
bias1
)
output2
=
output2
.
reshape
(
*
batch_shape
,
output2
.
shape
[
-
1
])
return
output2
if
not
return_residual
else
(
output2
,
x
)
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else F.relu)
        )
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                (pre_act,) = rest
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            (bias1,) = rest
            if process_group is not None and sequence_parallel:
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    weight1,
                    bias1,
                    activation == "gelu_approx",
                    True,
                    ctx.heuristic,
                )
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (
                gelu_bwd
                if activation == "gelu_approx"
                else (sqrelu_bwd if activation == "sqrelu" else relu_bwd)
            )
            with torch.jit.fuser("fuser2"):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    grad_pre_act,
                    ctx.needs_input_grad[2],
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1 = F.linear(
                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
                )
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (
            grad_input,
            grad_weight1,
            grad_bias1,
            grad_weight2,
            grad_bias2,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def fused_mlp_func(
    x: Tensor,
    weight1: Tensor,
    weight2: Tensor,
    bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    activation: str = "gelu_approx",
    save_pre_act: bool = True,
    return_residual: bool = False,
    checkpoint_lvl: int = 0,
    heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    assert activation in ["gelu_approx", "relu", "sqrelu"]
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0)
    if (
        x.is_cuda
        and weight1.is_cuda
        and weight2.is_cuda
        and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda)
        and dtype_eligible
        and dim_eligible
    ):
        return FusedMLPFunc.apply(
            x,
            weight1,
            bias1,
            weight2,
            bias2,
            activation,
            save_pre_act,
            return_residual,
            checkpoint_lvl,
            heuristic,
            process_group,
            sequence_parallel,
        )
    else:
        assert process_group is None
        pre_act = F.linear(x, weight1, bias1)
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else partial(F.relu, inplace=True)
        )
        output1 = activation_fn(pre_act)
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
class FusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        activation="gelu_approx",
        return_residual=False,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt
                implementation is slower than the unfused version.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic if activation != "sqrelu" else -1
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == "auto":
            if self.activation == "gelu_approx":
                if torch.cuda.get_device_capability("cuda") == (9, 0):
                    heuristic = -1
                else:
                    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=process_group,
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
class ParallelFusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation="gelu_approx",
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        sequence_parallel=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        assert process_group is not None
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic if activation != "sqrelu" else -1
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )

    def forward(self, x):
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == "auto":
            if self.activation == "gelu_approx":
                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
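For reference, a minimal usage sketch of the FusedMLP module defined above. This is not part of the deleted file; the import path and the availability of the fused_dense_cuda extension are assumptions (this copy lived under vllm_flash_attn, while the upstream package exposes it under flash_attn.ops.fused_dense).

import torch
from flash_attn.ops.fused_dense import FusedMLP  # assumed import path

# Hidden dim must be divisible by 8 for the fused gelu path; fp16/bf16 on CUDA is required.
mlp = FusedMLP(in_features=1024, hidden_features=4096, activation="gelu_approx").half().cuda()
x = torch.randn(2, 128, 1024, dtype=torch.float16, device="cuda")
y = mlp(x)  # same shape as x: (2, 128, 1024), since out_features defaults to in_features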
vllm_flash_attn/ops/layer_norm.py
deleted
100644 → 0
View file @
6ac8e63a
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
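For reference, a minimal usage sketch of the pre-norm residual pattern that dropout_add_layer_norm serves. This is not part of the deleted file; it assumes the dropout_layer_norm CUDA extension is installed and uses the upstream flash_attn.ops.layer_norm import path.

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm  # assumed import path

hidden = 1024
x0 = torch.randn(8, hidden, dtype=torch.float16, device="cuda")  # block output at the first layer
weight = torch.ones(hidden, dtype=torch.float16, device="cuda")
bias = torch.zeros(hidden, dtype=torch.float16, device="cuda")

# With prenorm=True the call returns both the normalized output and the updated
# residual stream (dropout(x0) + residual), which feeds the next block.
out, residual = dropout_add_layer_norm(
    x0, None, weight, bias, dropout_p=0.1, epsilon=1e-5,
    prenorm=True, residual_in_fp32=True,
)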
vllm_flash_attn/ops/rms_norm.py
deleted
100644 → 0
View file @
6ac8e63a
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )
class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
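For reference, a minimal usage sketch of the RMSNorm wrapper above, which simply reuses DropoutAddLayerNormFn with is_rms_norm=True and no bias. Not part of the deleted file; the import path and the presence of the dropout_layer_norm extension are assumptions.

import torch
from flash_attn.ops.rms_norm import RMSNorm  # assumed import path

norm = RMSNorm(hidden_size=1024, eps=1e-5).cuda().half()
x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
y = norm(x)  # same shape; x normalized by rsqrt(mean(x**2) + eps) and scaled by norm.weight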
vllm_flash_attn/ops/triton/__init__.py
deleted
100644 → 0
View file @
6ac8e63a
vllm_flash_attn/ops/triton/cross_entropy.py
deleted
100644 → 0
View file @
6ac8e63a
# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional, Union

import torch

from einops import rearrange

import triton
import triton.language as tl

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignored_index:
        loss = 0.0
        z_loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignored_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                z_losses,
                logits,
                labels,
                smoothing,
                logit_scale,
                lse_square_scale,
                ignored_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if n_splits > 1:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignored_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            losses.masked_fill_(labels == ignored_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignored_index = ignored_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses, z_losses
    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are only for logging.

        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignored_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        return dlogits, None, None, None, None, None, None, None, None
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignored_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        logit_scale: float. Multiply logits by this scale before calculating the loss.
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits,
        labels,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignored_index,
        inplace_backward,
        process_group,
    )
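For reference, a minimal sketch comparing the Triton loss above against PyTorch's reference implementation. Not part of the deleted file; the import path is an assumption and a CUDA device is required.

import torch
import torch.nn.functional as F
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss  # assumed import path

logits = torch.randn(32, 50257, device="cuda", dtype=torch.float32, requires_grad=True)
labels = torch.randint(0, 50257, (32,), device="cuda")

losses, z_losses = cross_entropy_loss(logits, labels)   # per-example losses, no reduction
ref = F.cross_entropy(logits, labels, reduction="none")
print(torch.allclose(losses, ref, atol=1e-4))            # expected to be close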
vllm_flash_attn/ops/triton/k_activations.py
deleted
100644 → 0
View file @
6ac8e63a
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )
@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations
    # in that it does not require the input to retrospectively compute its gradient
    # here the input is the downstream gradient, and we return the upstream gradient directly
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
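For reference, a minimal PyTorch sketch of the same tanh-approximate GELU formula used by the gelu_approx kernel above. Not part of the deleted file; it is only a host-side reference that can be used when sanity-checking the Triton kernels.

import math
import torch

def gelu_approx_ref(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the Triton gelu_approx kernel: 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(8)
print(torch.allclose(gelu_approx_ref(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6))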
vllm_flash_attn/ops/triton/layer_norm.py
deleted
100644 → 0
View file @
6ac8e63a
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl
def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)
def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
            dtype
        )
        return (out, out1) if not prenorm else (out, out1, x)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
1
),
triton
.
Config
({},
num_warps
=
2
),
triton
.
Config
({},
num_warps
=
4
),
triton
.
Config
({},
num_warps
=
8
),
triton
.
Config
({},
num_warps
=
16
),
triton
.
Config
({},
num_warps
=
32
),
],
key
=
[
"N"
,
"HAS_DRESIDUAL"
,
"STORE_DRESIDUAL"
,
"IS_RMS_NORM"
,
"HAS_BIAS"
,
"HAS_DROPOUT"
],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@
triton
.
heuristics
({
"HAS_ROWSCALE"
:
lambda
args
:
args
[
"ROWSCALE"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_DY1"
:
lambda
args
:
args
[
"DY1"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_DX1"
:
lambda
args
:
args
[
"DX1"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_B1"
:
lambda
args
:
args
[
"DB1"
]
is
not
None
})
@
triton
.
heuristics
({
"RECOMPUTE_OUTPUT"
:
lambda
args
:
args
[
"Y"
]
is
not
None
})
@
triton
.
jit
def
_layer_norm_bwd_kernel
(
X
,
# pointer to the input
W
,
# pointer to the weights
B
,
# pointer to the biases
Y
,
# pointer to the output to be recomputed
DY
,
# pointer to the output gradient
DX
,
# pointer to the input gradient
DW
,
# pointer to the partial sum of weights gradient
DB
,
# pointer to the partial sum of biases gradient
DRESIDUAL
,
W1
,
DY1
,
DX1
,
DW1
,
DB1
,
DRESIDUAL_IN
,
ROWSCALE
,
SEEDS
,
Mean
,
# pointer to the mean
Rstd
,
# pointer to the 1/std
stride_x_row
,
# how much to increase the pointer when moving by 1 row
stride_y_row
,
stride_dy_row
,
stride_dx_row
,
stride_dres_row
,
stride_dy1_row
,
stride_dx1_row
,
stride_dres_in_row
,
M
,
# number of rows in X
N
,
# number of columns in X
eps
,
# epsilon to avoid division by zero
dropout_p
,
rows_per_program
,
IS_RMS_NORM
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
HAS_DRESIDUAL
:
tl
.
constexpr
,
STORE_DRESIDUAL
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
HAS_DROPOUT
:
tl
.
constexpr
,
HAS_ROWSCALE
:
tl
.
constexpr
,
HAS_DY1
:
tl
.
constexpr
,
HAS_DX1
:
tl
.
constexpr
,
HAS_B1
:
tl
.
constexpr
,
RECOMPUTE_OUTPUT
:
tl
.
constexpr
,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id
=
tl
.
program_id
(
0
)
row_start
=
row_block_id
*
rows_per_program
# Do not early exit if row_start >= M, because we need to write DW and DB
cols
=
tl
.
arange
(
0
,
BLOCK_N
)
mask
=
cols
<
N
X
+=
row_start
*
stride_x_row
if
HAS_DRESIDUAL
:
DRESIDUAL
+=
row_start
*
stride_dres_row
if
STORE_DRESIDUAL
:
DRESIDUAL_IN
+=
row_start
*
stride_dres_in_row
DY
+=
row_start
*
stride_dy_row
DX
+=
row_start
*
stride_dx_row
if
HAS_DY1
:
DY1
+=
row_start
*
stride_dy1_row
if
HAS_DX1
:
DX1
+=
row_start
*
stride_dx1_row
if
RECOMPUTE_OUTPUT
:
Y
+=
row_start
*
stride_y_row
w
=
tl
.
load
(
W
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
if
RECOMPUTE_OUTPUT
and
HAS_BIAS
:
b
=
tl
.
load
(
B
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
HAS_DY1
:
w1
=
tl
.
load
(
W1
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
dw
=
tl
.
zeros
((
BLOCK_N
,),
dtype
=
tl
.
float32
)
if
HAS_BIAS
:
db
=
tl
.
zeros
((
BLOCK_N
,),
dtype
=
tl
.
float32
)
if
HAS_DY1
:
dw1
=
tl
.
zeros
((
BLOCK_N
,),
dtype
=
tl
.
float32
)
if
HAS_B1
:
db1
=
tl
.
zeros
((
BLOCK_N
,),
dtype
=
tl
.
float32
)
row_end
=
min
((
row_block_id
+
1
)
*
rows_per_program
,
M
)
for
row
in
range
(
row_start
,
row_end
):
# Load data to SRAM
x
=
tl
.
load
(
X
+
cols
,
mask
=
mask
,
other
=
0
).
to
(
tl
.
float32
)
dy
=
tl
.
load
(
DY
+
cols
,
mask
=
mask
,
other
=
0
).
to
(
tl
.
float32
)
if
HAS_DY1
:
dy1
=
tl
.
load
(
DY1
+
cols
,
mask
=
mask
,
other
=
0
).
to
(
tl
.
float32
)
if
not
IS_RMS_NORM
:
mean
=
tl
.
load
(
Mean
+
row
)
rstd
=
tl
.
load
(
Rstd
+
row
)
# Compute dx
xhat
=
(
x
-
mean
)
*
rstd
if
not
IS_RMS_NORM
else
x
*
rstd
xhat
=
tl
.
where
(
mask
,
xhat
,
0.0
)
if
RECOMPUTE_OUTPUT
:
y
=
xhat
*
w
+
b
if
HAS_BIAS
else
xhat
*
w
tl
.
store
(
Y
+
cols
,
y
,
mask
=
mask
)
wdy
=
w
*
dy
dw
+=
dy
*
xhat
if
HAS_BIAS
:
db
+=
dy
if
HAS_DY1
:
wdy
+=
w1
*
dy1
dw1
+=
dy1
*
xhat
if
HAS_B1
:
db1
+=
dy1
if
not
IS_RMS_NORM
:
c1
=
tl
.
sum
(
xhat
*
wdy
,
axis
=
0
)
/
N
c2
=
tl
.
sum
(
wdy
,
axis
=
0
)
/
N
dx
=
(
wdy
-
(
xhat
*
c1
+
c2
))
*
rstd
else
:
c1
=
tl
.
sum
(
xhat
*
wdy
,
axis
=
0
)
/
N
dx
=
(
wdy
-
xhat
*
c1
)
*
rstd
if
HAS_DRESIDUAL
:
dres
=
tl
.
load
(
DRESIDUAL
+
cols
,
mask
=
mask
,
other
=
0
).
to
(
tl
.
float32
)
dx
+=
dres
# Write dx
if
STORE_DRESIDUAL
:
tl
.
store
(
DRESIDUAL_IN
+
cols
,
dx
,
mask
=
mask
)
if
HAS_DX1
:
if
HAS_DROPOUT
:
keep_mask
=
(
tl
.
rand
(
tl
.
load
(
SEEDS
+
M
+
row
).
to
(
tl
.
uint32
),
cols
,
n_rounds
=
7
)
>
dropout_p
)
dx1
=
tl
.
where
(
keep_mask
,
dx
/
(
1.0
-
dropout_p
),
0.0
)
else
:
dx1
=
dx
tl
.
store
(
DX1
+
cols
,
dx1
,
mask
=
mask
)
if
HAS_DROPOUT
:
keep_mask
=
tl
.
rand
(
tl
.
load
(
SEEDS
+
row
).
to
(
tl
.
uint32
),
cols
,
n_rounds
=
7
)
>
dropout_p
dx
=
tl
.
where
(
keep_mask
,
dx
/
(
1.0
-
dropout_p
),
0.0
)
if
HAS_ROWSCALE
:
rowscale
=
tl
.
load
(
ROWSCALE
+
row
).
to
(
tl
.
float32
)
dx
*=
rowscale
tl
.
store
(
DX
+
cols
,
dx
,
mask
=
mask
)
X
+=
stride_x_row
if
HAS_DRESIDUAL
:
DRESIDUAL
+=
stride_dres_row
if
STORE_DRESIDUAL
:
DRESIDUAL_IN
+=
stride_dres_in_row
if
RECOMPUTE_OUTPUT
:
Y
+=
stride_y_row
DY
+=
stride_dy_row
DX
+=
stride_dx_row
if
HAS_DY1
:
DY1
+=
stride_dy1_row
if
HAS_DX1
:
DX1
+=
stride_dx1_row
tl
.
store
(
DW
+
row_block_id
*
N
+
cols
,
dw
,
mask
=
mask
)
if
HAS_BIAS
:
tl
.
store
(
DB
+
row_block_id
*
N
+
cols
,
db
,
mask
=
mask
)
if
HAS_DY1
:
tl
.
store
(
DW1
+
row_block_id
*
N
+
cols
,
dw1
,
mask
=
mask
)
if
HAS_B1
:
tl
.
store
(
DB1
+
row_block_id
*
N
+
cols
,
db1
,
mask
=
mask
)
def
_layer_norm_bwd
(
dy
,
x
,
weight
,
bias
,
eps
,
mean
,
rstd
,
dresidual
=
None
,
dy1
=
None
,
weight1
=
None
,
bias1
=
None
,
seeds
=
None
,
dropout_p
=
0.0
,
rowscale
=
None
,
has_residual
=
False
,
has_x1
=
False
,
is_rms_norm
=
False
,
x_dtype
=
None
,
recompute_output
=
False
,
):
M
,
N
=
x
.
shape
assert
x
.
stride
(
-
1
)
==
1
assert
dy
.
stride
(
-
1
)
==
1
assert
dy
.
shape
==
(
M
,
N
)
if
dresidual
is
not
None
:
assert
dresidual
.
stride
(
-
1
)
==
1
assert
dresidual
.
shape
==
(
M
,
N
)
assert
weight
.
shape
==
(
N
,)
assert
weight
.
stride
(
-
1
)
==
1
if
bias
is
not
None
:
assert
bias
.
stride
(
-
1
)
==
1
assert
bias
.
shape
==
(
N
,)
if
dy1
is
not
None
:
assert
weight1
is
not
None
assert
dy1
.
shape
==
dy
.
shape
assert
dy1
.
stride
(
-
1
)
==
1
if
weight1
is
not
None
:
assert
weight1
.
shape
==
(
N
,)
assert
weight1
.
stride
(
-
1
)
==
1
if
bias1
is
not
None
:
assert
bias1
.
shape
==
(
N
,)
assert
bias1
.
stride
(
-
1
)
==
1
if
seeds
is
not
None
:
assert
seeds
.
is_contiguous
()
assert
seeds
.
shape
==
(
M
if
not
has_x1
else
M
*
2
,)
if
rowscale
is
not
None
:
assert
rowscale
.
is_contiguous
()
assert
rowscale
.
shape
==
(
M
,)
# allocate output
dx
=
(
torch
.
empty_like
(
x
)
if
x_dtype
is
None
else
torch
.
empty
(
M
,
N
,
dtype
=
x_dtype
,
device
=
x
.
device
)
)
dresidual_in
=
(
torch
.
empty_like
(
x
)
if
has_residual
and
(
dx
.
dtype
!=
x
.
dtype
or
dropout_p
>
0.0
or
rowscale
is
not
None
or
has_x1
)
else
None
)
dx1
=
torch
.
empty_like
(
dx
)
if
(
has_x1
and
dropout_p
>
0.0
)
else
None
y
=
torch
.
empty
(
M
,
N
,
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)
if
recompute_output
else
None
if
recompute_output
:
assert
weight1
is
None
,
"recompute_output is not supported with parallel LayerNorm"
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE
=
65536
//
x
.
element_size
()
BLOCK_N
=
min
(
MAX_FUSED_SIZE
,
triton
.
next_power_of_2
(
N
))
if
N
>
BLOCK_N
:
raise
RuntimeError
(
"This layer norm doesn't support feature dim >= 64KB."
)
sm_count
=
torch
.
cuda
.
get_device_properties
(
x
.
device
).
multi_processor_count
_dw
=
torch
.
empty
((
sm_count
,
N
),
dtype
=
torch
.
float32
,
device
=
weight
.
device
)
_db
=
(
torch
.
empty
((
sm_count
,
N
),
dtype
=
torch
.
float32
,
device
=
bias
.
device
)
if
bias
is
not
None
else
None
)
_dw1
=
torch
.
empty_like
(
_dw
)
if
weight1
is
not
None
else
None
_db1
=
torch
.
empty_like
(
_db
)
if
bias1
is
not
None
else
None
rows_per_program
=
math
.
ceil
(
M
/
sm_count
)
grid
=
(
sm_count
,)
with
torch
.
cuda
.
device
(
x
.
device
.
index
):
_layer_norm_bwd_kernel
[
grid
](
x
,
weight
,
bias
,
y
,
dy
,
dx
,
_dw
,
_db
,
dresidual
,
weight1
,
dy1
,
dx1
,
_dw1
,
_db1
,
dresidual_in
,
rowscale
,
seeds
,
mean
,
rstd
,
x
.
stride
(
0
),
0
if
not
recompute_output
else
y
.
stride
(
0
),
dy
.
stride
(
0
),
dx
.
stride
(
0
),
dresidual
.
stride
(
0
)
if
dresidual
is
not
None
else
0
,
dy1
.
stride
(
0
)
if
dy1
is
not
None
else
0
,
dx1
.
stride
(
0
)
if
dx1
is
not
None
else
0
,
dresidual_in
.
stride
(
0
)
if
dresidual_in
is
not
None
else
0
,
M
,
N
,
eps
,
dropout_p
,
rows_per_program
,
is_rms_norm
,
BLOCK_N
,
dresidual
is
not
None
,
dresidual_in
is
not
None
,
bias
is
not
None
,
dropout_p
>
0.0
,
)
dw
=
_dw
.
sum
(
0
).
to
(
weight
.
dtype
)
db
=
_db
.
sum
(
0
).
to
(
bias
.
dtype
)
if
bias
is
not
None
else
None
dw1
=
_dw1
.
sum
(
0
).
to
(
weight1
.
dtype
)
if
weight1
is
not
None
else
None
db1
=
_db1
.
sum
(
0
).
to
(
bias1
.
dtype
)
if
bias1
is
not
None
else
None
# Don't need to compute dresidual_in separately in this case
if
has_residual
and
dx
.
dtype
==
x
.
dtype
and
dropout_p
==
0.0
and
rowscale
is
None
:
dresidual_in
=
dx
if
has_x1
and
dropout_p
==
0.0
:
dx1
=
dx
return
(
(
dx
,
dw
,
db
,
dresidual_in
,
dx1
,
dw1
,
db1
)
if
not
recompute_output
else
(
dx
,
dw
,
db
,
dresidual_in
,
dx1
,
dw1
,
db1
,
y
)
)
class
LayerNormFn
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
x
,
weight
,
bias
,
residual
=
None
,
x1
=
None
,
weight1
=
None
,
bias1
=
None
,
eps
=
1e-6
,
dropout_p
=
0.0
,
rowscale
=
None
,
prenorm
=
False
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
return_dropout_mask
=
False
,
):
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
residual
is
not
None
:
assert
residual
.
shape
==
x_shape_og
residual
=
residual
.
reshape
(
-
1
,
residual
.
shape
[
-
1
])
if
residual
.
stride
(
-
1
)
!=
1
:
residual
=
residual
.
contiguous
()
if
x1
is
not
None
:
assert
x1
.
shape
==
x_shape_og
assert
rowscale
is
None
,
"rowscale is not supported with parallel LayerNorm"
x1
=
x1
.
reshape
(
-
1
,
x1
.
shape
[
-
1
])
if
x1
.
stride
(
-
1
)
!=
1
:
x1
=
x1
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
if
weight1
is
not
None
:
weight1
=
weight1
.
contiguous
()
if
bias1
is
not
None
:
bias1
=
bias1
.
contiguous
()
if
rowscale
is
not
None
:
rowscale
=
rowscale
.
reshape
(
-
1
).
contiguous
()
residual_dtype
=
(
residual
.
dtype
if
residual
is
not
None
else
(
torch
.
float32
if
residual_in_fp32
else
None
)
)
y
,
y1
,
mean
,
rstd
,
residual_out
,
seeds
,
dropout_mask
,
dropout_mask1
=
_layer_norm_fwd
(
x
,
weight
,
bias
,
eps
,
residual
,
x1
,
weight1
,
bias1
,
dropout_p
=
dropout_p
,
rowscale
=
rowscale
,
residual_dtype
=
residual_dtype
,
is_rms_norm
=
is_rms_norm
,
return_dropout_mask
=
return_dropout_mask
,
)
ctx
.
save_for_backward
(
residual_out
,
weight
,
bias
,
weight1
,
bias1
,
rowscale
,
seeds
,
mean
,
rstd
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
eps
=
eps
ctx
.
dropout_p
=
dropout_p
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
has_residual
=
residual
is
not
None
ctx
.
has_x1
=
x1
is
not
None
ctx
.
prenorm
=
prenorm
ctx
.
x_dtype
=
x
.
dtype
y
=
y
.
reshape
(
x_shape_og
)
y1
=
y1
.
reshape
(
x_shape_og
)
if
y1
is
not
None
else
None
residual_out
=
residual_out
.
reshape
(
x_shape_og
)
if
residual_out
is
not
None
else
None
dropout_mask
=
dropout_mask
.
reshape
(
x_shape_og
)
if
dropout_mask
is
not
None
else
None
dropout_mask1
=
dropout_mask1
.
reshape
(
x_shape_og
)
if
dropout_mask1
is
not
None
else
None
if
not
return_dropout_mask
:
if
weight1
is
None
:
return
y
if
not
prenorm
else
(
y
,
residual_out
)
else
:
return
(
y
,
y1
)
if
not
prenorm
else
(
y
,
y1
,
residual_out
)
else
:
if
weight1
is
None
:
return
(
(
y
,
dropout_mask
,
dropout_mask1
)
if
not
prenorm
else
(
y
,
residual_out
,
dropout_mask
,
dropout_mask1
)
)
else
:
return
(
(
y
,
y1
,
dropout_mask
,
dropout_mask1
)
if
not
prenorm
else
(
y
,
y1
,
residual_out
,
dropout_mask
,
dropout_mask1
)
)
@
staticmethod
def
backward
(
ctx
,
dy
,
*
args
):
x
,
weight
,
bias
,
weight1
,
bias1
,
rowscale
,
seeds
,
mean
,
rstd
=
ctx
.
saved_tensors
dy
=
dy
.
reshape
(
-
1
,
dy
.
shape
[
-
1
])
if
dy
.
stride
(
-
1
)
!=
1
:
dy
=
dy
.
contiguous
()
assert
dy
.
shape
==
x
.
shape
if
weight1
is
not
None
:
dy1
,
args
=
args
[
0
],
args
[
1
:]
dy1
=
dy1
.
reshape
(
-
1
,
dy1
.
shape
[
-
1
])
if
dy1
.
stride
(
-
1
)
!=
1
:
dy1
=
dy1
.
contiguous
()
assert
dy1
.
shape
==
x
.
shape
else
:
dy1
=
None
if
ctx
.
prenorm
:
dresidual
=
args
[
0
]
dresidual
=
dresidual
.
reshape
(
-
1
,
dresidual
.
shape
[
-
1
])
if
dresidual
.
stride
(
-
1
)
!=
1
:
dresidual
=
dresidual
.
contiguous
()
assert
dresidual
.
shape
==
x
.
shape
else
:
dresidual
=
None
dx
,
dw
,
db
,
dresidual_in
,
dx1
,
dw1
,
db1
=
_layer_norm_bwd
(
dy
,
x
,
weight
,
bias
,
ctx
.
eps
,
mean
,
rstd
,
dresidual
,
dy1
,
weight1
,
bias1
,
seeds
,
ctx
.
dropout_p
,
rowscale
,
ctx
.
has_residual
,
ctx
.
has_x1
,
ctx
.
is_rms_norm
,
x_dtype
=
ctx
.
x_dtype
,
)
return
(
dx
.
reshape
(
ctx
.
x_shape_og
),
dw
,
db
,
dresidual_in
.
reshape
(
ctx
.
x_shape_og
)
if
ctx
.
has_residual
else
None
,
dx1
.
reshape
(
ctx
.
x_shape_og
)
if
dx1
is
not
None
else
None
,
dw1
,
db1
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
)
def
layer_norm_fn
(
x
,
weight
,
bias
,
residual
=
None
,
x1
=
None
,
weight1
=
None
,
bias1
=
None
,
eps
=
1e-6
,
dropout_p
=
0.0
,
rowscale
=
None
,
prenorm
=
False
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
return_dropout_mask
=
False
,
):
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
residual
,
x1
,
weight1
,
bias1
,
eps
,
dropout_p
,
rowscale
,
prenorm
,
residual_in_fp32
,
is_rms_norm
,
return_dropout_mask
,
)
def
rms_norm_fn
(
x
,
weight
,
bias
,
residual
=
None
,
x1
=
None
,
weight1
=
None
,
bias1
=
None
,
eps
=
1e-6
,
dropout_p
=
0.0
,
rowscale
=
None
,
prenorm
=
False
,
residual_in_fp32
=
False
,
return_dropout_mask
=
False
,
):
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
residual
,
x1
,
weight1
,
bias1
,
eps
,
dropout_p
,
rowscale
,
prenorm
,
residual_in_fp32
,
True
,
return_dropout_mask
,
)
class
RMSNorm
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
=
1e-5
,
dropout_p
=
0.0
,
device
=
None
,
dtype
=
None
):
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
eps
=
eps
if
dropout_p
>
0.0
:
self
.
drop
=
torch
.
nn
.
Dropout
(
dropout_p
)
else
:
self
.
drop
=
None
self
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
register_parameter
(
"bias"
,
None
)
self
.
reset_parameters
()
def
reset_parameters
(
self
):
torch
.
nn
.
init
.
ones_
(
self
.
weight
)
def
forward
(
self
,
x
,
residual
=
None
,
prenorm
=
False
,
residual_in_fp32
=
False
):
return
rms_norm_fn
(
x
,
self
.
weight
,
self
.
bias
,
residual
=
residual
,
eps
=
self
.
eps
,
dropout_p
=
self
.
drop
.
p
if
self
.
drop
is
not
None
and
self
.
training
else
0.0
,
prenorm
=
prenorm
,
residual_in_fp32
=
residual_in_fp32
,
)
class
LayerNormLinearFn
(
torch
.
autograd
.
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
x
,
norm_weight
,
norm_bias
,
linear_weight
,
linear_bias
,
residual
=
None
,
eps
=
1e-6
,
prenorm
=
False
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
):
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
residual
is
not
None
:
assert
residual
.
shape
==
x_shape_og
residual
=
residual
.
reshape
(
-
1
,
residual
.
shape
[
-
1
])
if
residual
.
stride
(
-
1
)
!=
1
:
residual
=
residual
.
contiguous
()
norm_weight
=
norm_weight
.
contiguous
()
if
norm_bias
is
not
None
:
norm_bias
=
norm_bias
.
contiguous
()
residual_dtype
=
(
residual
.
dtype
if
residual
is
not
None
else
(
torch
.
float32
if
residual_in_fp32
else
None
)
)
y
,
_
,
mean
,
rstd
,
residual_out
,
*
rest
=
_layer_norm_fwd
(
x
,
norm_weight
,
norm_bias
,
eps
,
residual
,
out_dtype
=
None
if
not
torch
.
is_autocast_enabled
()
else
torch
.
get_autocast_gpu_dtype
(),
residual_dtype
=
residual_dtype
,
is_rms_norm
=
is_rms_norm
,
)
y
=
y
.
reshape
(
x_shape_og
)
dtype
=
torch
.
get_autocast_gpu_dtype
()
if
torch
.
is_autocast_enabled
()
else
y
.
dtype
linear_weight
=
linear_weight
.
to
(
dtype
)
linear_bias
=
linear_bias
.
to
(
dtype
)
if
linear_bias
is
not
None
else
None
out
=
F
.
linear
(
y
.
to
(
linear_weight
.
dtype
),
linear_weight
,
linear_bias
)
# We don't store y, will be recomputed in the backward pass to save memory
ctx
.
save_for_backward
(
residual_out
,
norm_weight
,
norm_bias
,
linear_weight
,
mean
,
rstd
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
eps
=
eps
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
has_residual
=
residual
is
not
None
ctx
.
prenorm
=
prenorm
ctx
.
x_dtype
=
x
.
dtype
ctx
.
linear_bias_is_none
=
linear_bias
is
None
return
out
if
not
prenorm
else
(
out
,
residual_out
.
reshape
(
x_shape_og
))
@
staticmethod
@
custom_bwd
def
backward
(
ctx
,
dout
,
*
args
):
x
,
norm_weight
,
norm_bias
,
linear_weight
,
mean
,
rstd
=
ctx
.
saved_tensors
dout
=
dout
.
reshape
(
-
1
,
dout
.
shape
[
-
1
])
dy
=
F
.
linear
(
dout
,
linear_weight
.
t
())
dlinear_bias
=
None
if
ctx
.
linear_bias_is_none
else
dout
.
sum
(
0
)
if
dy
.
stride
(
-
1
)
!=
1
:
dy
=
dy
.
contiguous
()
assert
dy
.
shape
==
x
.
shape
if
ctx
.
prenorm
:
dresidual
=
args
[
0
]
dresidual
=
dresidual
.
reshape
(
-
1
,
dresidual
.
shape
[
-
1
])
if
dresidual
.
stride
(
-
1
)
!=
1
:
dresidual
=
dresidual
.
contiguous
()
assert
dresidual
.
shape
==
x
.
shape
else
:
dresidual
=
None
dx
,
dnorm_weight
,
dnorm_bias
,
dresidual_in
,
_
,
_
,
_
,
y
=
_layer_norm_bwd
(
dy
,
x
,
norm_weight
,
norm_bias
,
ctx
.
eps
,
mean
,
rstd
,
dresidual
=
dresidual
,
has_residual
=
ctx
.
has_residual
,
is_rms_norm
=
ctx
.
is_rms_norm
,
x_dtype
=
ctx
.
x_dtype
,
recompute_output
=
True
,
)
dlinear_weight
=
torch
.
einsum
(
"bo,bi->oi"
,
dout
,
y
)
return
(
dx
.
reshape
(
ctx
.
x_shape_og
),
dnorm_weight
,
dnorm_bias
,
dlinear_weight
,
dlinear_bias
,
dresidual_in
.
reshape
(
ctx
.
x_shape_og
)
if
ctx
.
has_residual
else
None
,
None
,
None
,
None
,
None
,
)
def
layer_norm_linear_fn
(
x
,
norm_weight
,
norm_bias
,
linear_weight
,
linear_bias
,
residual
=
None
,
eps
=
1e-6
,
prenorm
=
False
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
):
return
LayerNormLinearFn
.
apply
(
x
,
norm_weight
,
norm_bias
,
linear_weight
,
linear_bias
,
residual
,
eps
,
prenorm
,
residual_in_fp32
,
is_rms_norm
,
)
vllm_flash_attn/ops/triton/linear.py
deleted
100644 → 0
View file @
6ac8e63a
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from
typing
import
Optional
import
torch
import
triton
import
triton.language
as
tl
from
triton.ops.matmul_perf_model
import
early_config_prune
,
estimate_matmul_time
from
flash_attn.ops.triton.k_activations
import
(
gelu
,
gelu_approx
,
gelu_approx_grad
,
gelu_grad
,
squared_relu
,
squared_relu_grad
,
)
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
def
init_to_zero
(
name
):
return
lambda
nargs
:
nargs
[
name
].
zero_
()
def
get_configs_io_bound
():
configs
=
[]
for
num_stages
in
[
2
,
3
,
4
,
5
,
6
]:
for
block_m
in
[
16
,
32
]:
for
block_k
in
[
32
,
64
]:
for
block_n
in
[
32
,
64
,
128
,
256
]:
num_warps
=
2
if
block_n
<=
64
else
4
configs
.
append
(
triton
.
Config
(
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
1
,
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
)
)
# split_k not used
# for split_k in [2, 4, 8, 16]:
# configs.append(triton.Config(
# {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
# num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return
configs
@
triton
.
autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
# good for int8
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
]
+
get_configs_io_bound
(),
key
=
[
"CACHE_KEY_M"
,
"CACHE_KEY_N"
,
"CACHE_KEY_K"
],
prune_configs_by
=
{
"early_config_prune"
:
early_config_prune
,
"perf_model"
:
estimate_matmul_time
,
"top_k"
:
10
,
},
)
@
triton
.
heuristics
(
{
"EVEN_K"
:
lambda
args
:
args
[
"K"
]
%
(
args
[
"BLOCK_K"
]
*
args
[
"SPLIT_K"
])
==
0
,
}
)
@
triton
.
jit
def
kernel_fwd
(
C
,
# Pointers to matrices
ACT_INPUT
,
A
,
B
,
bias
,
# Matrix dimensions
M
,
N
,
K
,
CACHE_KEY_M
,
CACHE_KEY_N
,
CACHE_KEY_K
,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_cm
,
# stride_cn, # Assume that stride_cn == 1
stride_am
,
stride_ak
,
stride_bn
,
stride_bk
,
# Meta-parameters
BLOCK_M
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
# split k not used, not performant with activation, kept because early_config_prune is expecting it
SPLIT_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
A_ROWMAJOR
:
tl
.
constexpr
,
B_COLMAJOR
:
tl
.
constexpr
,
BIAS
:
tl
.
constexpr
,
SAVE_ACT_INPUT
:
tl
.
constexpr
,
ACTIVATION
:
tl
.
constexpr
,
):
"""
Kernel for computing Out = activation(A x W + C)
- Input has shape (M, K)
- Weight has shape (K, N)
- Bias has shape (N,)
- Output has shape (M, N)
- ActInputs (optional) has shape (M, N)
'ActInputs' optionally saves the A x W + C intermediate for backward computations
This kernel will consolidate over K
"""
pid
=
tl
.
program_id
(
axis
=
0
)
grid_m
=
(
M
+
BLOCK_M
-
1
)
//
BLOCK_M
grid_n
=
(
N
+
BLOCK_N
-
1
)
//
BLOCK_N
# re-order program ID for better L2 performance
width
=
GROUP_M
*
grid_n
group_id
=
pid
//
width
group_size
=
min
(
grid_m
-
group_id
*
GROUP_M
,
GROUP_M
)
pid_m
=
group_id
*
GROUP_M
+
(
pid
%
group_size
)
pid_n
=
(
pid
%
width
)
//
(
group_size
)
# now compute the block that each program will go through
# rm (resp. rn) denotes a range of indices
# for rows (resp. col) of C
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
rn
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
# trick to avoid masking on M and N axis
ram
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
rm
%
M
,
BLOCK_M
),
BLOCK_M
)
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
rn
%
N
,
BLOCK_N
),
BLOCK_N
)
rk
=
tl
.
arange
(
0
,
BLOCK_K
)
if
A_ROWMAJOR
:
A
=
A
+
(
ram
[:,
None
]
*
stride_am
+
rk
[
None
,
:])
else
:
A
=
A
+
(
ram
[:,
None
]
*
stride_am
+
rk
[
None
,
:]
*
stride_ak
)
if
B_COLMAJOR
:
B
=
B
+
(
rk
[:,
None
]
+
rbn
[
None
,
:]
*
stride_bn
)
else
:
B
=
B
+
(
rk
[:,
None
]
*
stride_bk
+
rbn
[
None
,
:]
*
stride_bn
)
acc
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
K
,
0
,
-
BLOCK_K
):
if
EVEN_K
:
a
=
tl
.
load
(
A
)
b
=
tl
.
load
(
B
)
else
:
a
=
tl
.
load
(
A
,
mask
=
rk
[
None
,
:]
<
k
,
other
=
0.0
)
b
=
tl
.
load
(
B
,
mask
=
rk
[:,
None
]
<
k
,
other
=
0.0
)
acc
+=
tl
.
dot
(
a
,
b
)
if
A_ROWMAJOR
:
A
+=
BLOCK_K
else
:
A
+=
BLOCK_K
*
stride_ak
if
B_COLMAJOR
:
B
+=
BLOCK_K
else
:
B
+=
BLOCK_K
*
stride_bk
# Putting bias after the matmul (instead of before) is faster, idk why
if
BIAS
:
bias
=
tl
.
load
(
bias
+
rn
,
mask
=
rn
<
N
,
other
=
0.0
).
to
(
tl
.
float32
)
acc
+=
bias
[
None
,
:]
# optional: save the activation inputs
if
SAVE_ACT_INPUT
:
# act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
act_in_ptrs
=
ACT_INPUT
+
ram
[:,
None
]
*
stride_cm
+
rbn
[
None
,
:]
tl
.
store
(
act_in_ptrs
,
acc
)
# optional: fused activation (while the data is in shared memory)
if
ACTIVATION
==
"gelu"
:
acc
=
gelu
(
acc
)
elif
ACTIVATION
==
"gelu_approx"
:
acc
=
gelu_approx
(
acc
)
elif
ACTIVATION
==
"squared_relu"
:
acc
=
squared_relu
(
acc
)
# rematerialize rm and rn to save registers
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
rn
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
# write back result
# C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
C
=
C
+
rm
[:,
None
]
*
stride_cm
+
rn
[
None
,
:]
mask
=
(
rm
<
M
)[:,
None
]
&
(
rn
<
N
)[
None
,
:]
tl
.
store
(
C
,
acc
)
def
triton_linear_act
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"id"
,
save_act_input
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
Compute e = activation(x @ weight.T + bias).
This wrapper kicks the `kernel_fwd` Triton kernel
:param x: input tensor
:param weight: weight matrix
:param bias: an optional bias tensor
:param activation: Activation name. Needs to be a Triton kernel.
:param act_input: an optional tensor to save the activation inputs (for backward)
:return: result tensor
"""
# if torch.is_autocast_enabled():
# dtype = torch.get_autocast_gpu_dtype()
# x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
assert
activation
in
[
"id"
,
"gelu"
,
"gelu_approx"
,
"squared_relu"
]
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
x_reshaped
=
x
.
reshape
(
batch_dim
,
n
)
if
x_reshaped
.
stride
(
0
)
>
1
and
x_reshaped
.
stride
(
1
)
>
1
:
x_reshaped
=
x_reshaped
.
contiguous
()
if
weight
.
stride
(
0
)
>
1
and
weight
.
stride
(
1
)
>
1
:
weight
=
weight
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
assert
(
x
.
dtype
==
weight
.
dtype
),
f
"Input and weight must have the same dtype, got
{
x
.
dtype
}
and
{
weight
.
dtype
}
"
if
bias
is
not
None
:
assert
(
x
.
dtype
==
bias
.
dtype
),
f
"Input and bias must have the same dtype, got
{
x
.
dtype
}
and
{
bias
.
dtype
}
"
assert
(
x_reshaped
.
shape
[
1
]
==
weight
.
shape
[
1
]
),
f
"Incompatible dimensions:
{
x_reshaped
.
shape
}
-
{
weight
.
shape
}
"
assert
(
bias
is
None
or
bias
.
shape
[
0
]
==
weight
.
shape
[
0
]
),
"Incompatible dimensions in between weight and bias"
M
,
K
=
x_reshaped
.
shape
N
,
K
=
weight
.
shape
output
=
torch
.
empty
((
M
,
N
),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
act_input
=
torch
.
empty_like
(
output
)
if
save_act_input
else
None
# 1D launch kernel where each block gets its own program.
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_N"
]),)
# noqa
kernel_fwd
[
grid
](
output
,
act_input
,
x_reshaped
,
weight
,
# data ptrs
bias
if
bias
is
not
None
else
x
,
# auto skip bias if not present
M
,
# shapes
N
,
K
,
M
//
32
,
# key for triton cache (limit number of compilations)
N
//
32
,
K
//
32
,
stride_cm
=
output
.
stride
(
0
),
# strides
# stride_cn=output.stride(1),
stride_am
=
x_reshaped
.
stride
(
0
),
stride_ak
=
x_reshaped
.
stride
(
1
),
stride_bk
=
weight
.
stride
(
1
),
stride_bn
=
weight
.
stride
(
0
),
BIAS
=
bias
is
not
None
,
# optional fused bias
SAVE_ACT_INPUT
=
save_act_input
,
# optional save activation inputs
ACTIVATION
=
activation
,
# optional fused activation
A_ROWMAJOR
=
x_reshaped
.
stride
(
1
)
==
1
,
B_COLMAJOR
=
weight
.
stride
(
1
)
==
1
,
GROUP_M
=
8
,
# speed optimization: group the programs
)
if
not
save_act_input
:
return
output
.
reshape
(
*
batch_shape
,
output
.
shape
[
-
1
])
else
:
return
(
output
.
reshape
(
*
batch_shape
,
output
.
shape
[
-
1
]),
act_input
.
reshape
(
*
batch_shape
,
act_input
.
shape
[
-
1
]),
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
# good for int8
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
]
+
get_configs_io_bound
(),
key
=
[
"CACHE_KEY_M"
,
"CACHE_KEY_N"
,
"CACHE_KEY_K"
],
prune_configs_by
=
{
"early_config_prune"
:
early_config_prune
,
"perf_model"
:
estimate_matmul_time
,
"top_k"
:
10
,
},
)
@
triton
.
heuristics
(
{
"EVEN_K"
:
lambda
args
:
args
[
"K"
]
%
(
args
[
"BLOCK_K"
]
*
args
[
"SPLIT_K"
])
==
0
,
}
)
@
triton
.
jit
def
kernel_bwd
(
C
,
# Pointers to matrices
ACT_INPUT
,
A
,
B
,
# Matrix dimensions
M
,
N
,
K
,
CACHE_KEY_M
,
CACHE_KEY_N
,
CACHE_KEY_K
,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_cm
,
# stride_cn, # Assume that stride_cn == 1
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
# Meta-parameters
BLOCK_M
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
# split k not used, not performant with activation, kept because early_config_prune is expecting it
SPLIT_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
ACTIVATION
:
tl
.
constexpr
,
):
"""
Kernel for computing Out = activation(A x W + C)
- Input has shape (M, K)
- Weight has shape (K, N)
- Output has shape (M, N)
- ActInputs (optional) has shape (M, N)
'ActInputs' optionally saves the A x W + C intermediate for backward computations
This kernel will consolidate over K
"""
pid
=
tl
.
program_id
(
axis
=
0
)
grid_m
=
(
M
+
BLOCK_M
-
1
)
//
BLOCK_M
grid_n
=
(
N
+
BLOCK_N
-
1
)
//
BLOCK_N
# re-order program ID for better L2 performance
width
=
GROUP_M
*
grid_n
group_id
=
pid
//
width
group_size
=
min
(
grid_m
-
group_id
*
GROUP_M
,
GROUP_M
)
pid_m
=
group_id
*
GROUP_M
+
(
pid
%
group_size
)
pid_n
=
(
pid
%
width
)
//
(
group_size
)
# now compute the block that each program will go through
# rm (resp. rn) denotes a range of indices
# for rows (resp. col) of C
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
rn
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
# trick to avoid masking on M and N axis
ram
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
rm
%
M
,
BLOCK_M
),
BLOCK_M
)
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
rn
%
N
,
BLOCK_N
),
BLOCK_N
)
rk
=
tl
.
arange
(
0
,
BLOCK_K
)
A
=
A
+
(
ram
[:,
None
]
*
stride_am
+
rk
[
None
,
:]
*
stride_ak
)
B
=
B
+
(
rk
[:,
None
]
*
stride_bk
+
rbn
[
None
,
:]
*
stride_bn
)
acc
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
K
,
0
,
-
BLOCK_K
):
if
EVEN_K
:
a
=
tl
.
load
(
A
)
b
=
tl
.
load
(
B
)
else
:
a
=
tl
.
load
(
A
,
mask
=
rk
[
None
,
:]
<
k
,
other
=
0.0
)
b
=
tl
.
load
(
B
,
mask
=
rk
[:,
None
]
<
k
,
other
=
0.0
)
acc
+=
tl
.
dot
(
a
,
b
)
A
+=
BLOCK_K
*
stride_ak
B
+=
BLOCK_K
*
stride_bk
# optional: fused activation (while the data is in shared memory)
if
ACTIVATION
!=
"id"
:
act_in_ptrs
=
ACT_INPUT
+
ram
[:,
None
]
*
stride_cm
+
rbn
[
None
,
:]
act_input
=
tl
.
load
(
act_in_ptrs
).
to
(
acc
.
dtype
)
if
ACTIVATION
==
"gelu"
:
acc
*=
gelu_grad
(
act_input
)
elif
ACTIVATION
==
"gelu_approx"
:
acc
*=
gelu_approx_grad
(
act_input
)
elif
ACTIVATION
==
"squared_relu"
:
acc
*=
squared_relu_grad
(
act_input
)
# rematerialize rm and rn to save registers
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
rn
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
# write back result
C
=
C
+
rm
[:,
None
]
*
stride_cm
+
rn
[
None
,
:]
mask
=
(
rm
<
M
)[:,
None
]
&
(
rn
<
N
)[
None
,
:]
tl
.
store
(
C
,
acc
,
mask
=
mask
)
def
triton_dgrad_act
(
grad_output
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
activation
:
str
=
"id"
,
act_input
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Compute e = activation(grad_output @ weight + bias).
This wrapper kicks the `kernel_fwd` Triton kernel
:param grad_output: input tensor
:param weight: weight matrix
:param activation: Activation name. Needs to be a Triton kernel.
:param act_input: an optional tensor to save the activation inputs (for backward)
:return: result tensor
"""
assert
activation
in
[
"id"
,
"gelu"
,
"gelu_approx"
,
"squared_relu"
]
batch_shape
,
n
=
grad_output
.
shape
[:
-
1
],
grad_output
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
grad_output_reshaped
=
grad_output
.
reshape
(
batch_dim
,
n
)
if
grad_output_reshaped
.
stride
(
0
)
>
1
and
grad_output_reshaped
.
stride
(
1
)
>
1
:
grad_output_reshaped
=
grad_output_reshaped
.
contiguous
()
if
weight
.
stride
(
0
)
>
1
and
weight
.
stride
(
1
)
>
1
:
weight
=
weight
.
contiguous
()
assert
(
grad_output
.
dtype
==
weight
.
dtype
),
f
"grad_output and weight must have the same dtype, got
{
grad_output
.
dtype
}
and
{
weight
.
dtype
}
"
assert
(
grad_output_reshaped
.
shape
[
1
]
==
weight
.
shape
[
0
]
),
f
"Incompatible dimensions:
{
grad_output_reshaped
.
shape
}
-
{
weight
.
shape
}
"
if
activation
!=
"id"
:
assert
act_input
is
not
None
,
f
"act_input is required for activation
{
activation
}
"
# M, N, K in bwd are different from M, N, K in fwd
M
,
K
=
grad_output_reshaped
.
shape
K
,
N
=
weight
.
shape
grad_input
=
torch
.
empty
((
M
,
N
),
device
=
grad_output
.
device
,
dtype
=
grad_output
.
dtype
)
# 1D launch kernel where each block gets its own program.
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_N"
]),)
# noqa
kernel_bwd
[
grid
](
grad_input
,
act_input
,
grad_output_reshaped
,
weight
,
# data ptrs
M
,
# shapes
N
,
K
,
M
//
32
,
# key for triton cache (limit number of compilations)
N
//
32
,
K
//
32
,
stride_cm
=
grad_input
.
stride
(
0
),
# strides
# stride_cn=grad_input.stride(1),
stride_am
=
grad_output_reshaped
.
stride
(
0
),
stride_ak
=
grad_output_reshaped
.
stride
(
1
),
stride_bk
=
weight
.
stride
(
0
),
stride_bn
=
weight
.
stride
(
1
),
ACTIVATION
=
activation
,
# optional fused activation
GROUP_M
=
8
,
# speed optimization: group the programs
)
return
grad_input
.
reshape
(
*
batch_shape
,
grad_input
.
shape
[
-
1
])
vllm_flash_attn/ops/triton/mlp.py
deleted
100644 → 0
View file @
6ac8e63a
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import
fused_dense_lib
as
fused_dense_cuda
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
flash_attn.ops.activations
import
sqrelu_bwd
,
sqrelu_fwd
from
flash_attn.ops.triton.linear
import
triton_dgrad_act
,
triton_linear_act
class
FusedDenseSqreluDenseFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
x
,
weight1
,
bias1
,
weight2
,
bias2
,
checkpoint_lvl
=
0
):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute act_input and gelu_out in the bwd
"""
if
torch
.
is_autocast_enabled
():
dtype
=
torch
.
get_autocast_gpu_dtype
()
x
,
weight1
,
bias1
,
weight2
,
bias2
=
[
a
.
to
(
dtype
=
dtype
)
for
a
in
[
x
,
weight1
,
bias1
,
weight2
,
bias2
]
]
is_bf16
=
x
.
dtype
==
torch
.
bfloat16
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
x
=
x
.
contiguous
()
weight1
=
weight1
.
contiguous
()
bias1
=
bias1
.
contiguous
()
weight2
=
weight2
.
contiguous
()
bias2
=
bias2
.
contiguous
()
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
if
is_bf16
:
act_input
=
fused_dense_cuda
.
linear_bias_forward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
)
output1
=
sqrelu_fwd
(
act_input
)
else
:
save_act_input
=
checkpoint_lvl
!=
2
result
=
triton_linear_act
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
activation
=
"squared_relu"
,
save_act_input
=
save_act_input
,
)
if
save_act_input
:
output1
,
act_input
=
result
else
:
output1
=
result
output2
=
fused_dense_cuda
.
linear_bias_forward
(
output1
,
weight2
,
bias2
)
ctx
.
checkpoint_lvl
=
checkpoint_lvl
if
checkpoint_lvl
==
0
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
,
act_input
,
output1
)
elif
checkpoint_lvl
==
1
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
,
act_input
)
elif
checkpoint_lvl
==
2
:
ctx
.
save_for_backward
(
x
,
weight1
,
bias1
,
weight2
)
return
output2
.
reshape
(
*
batch_shape
,
output2
.
shape
[
-
1
])
@
staticmethod
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
grad_output
=
grad_output
.
contiguous
()
checkpoint_lvl
=
ctx
.
checkpoint_lvl
x
,
weight1
,
bias1
,
weight2
,
*
rest
=
ctx
.
saved_tensors
batch_shape
,
n
=
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
batch_dim
=
batch_shape
.
numel
()
is_bf16
=
x
.
dtype
==
torch
.
bfloat16
if
checkpoint_lvl
==
0
:
act_input
,
output1
=
rest
elif
checkpoint_lvl
==
1
:
(
act_input
,)
=
rest
output1
=
sqrelu_fwd
(
act_input
)
elif
checkpoint_lvl
==
2
:
if
is_bf16
:
act_input
=
fused_dense_cuda
.
linear_bias_forward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
)
output1
=
sqrelu_fwd
(
act_input
)
else
:
output1
,
act_input
=
triton_linear_act
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
bias1
,
activation
=
"squared_relu"
,
save_act_input
=
True
,
)
if
is_bf16
:
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
grad_weight2
,
grad_bias2
=
fused_dense_cuda
.
linear_bias_wgrad
(
output1
,
grad_output
)
grad_output1
=
grad_output
@
weight2
grad_act_input
=
sqrelu_bwd
(
grad_output1
,
act_input
)
grad_input
,
grad_weight1
,
grad_bias1
=
fused_dense_cuda
.
linear_bias_backward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
grad_act_input
)
else
:
grad_output
=
grad_output
.
reshape
(
batch_dim
,
grad_output
.
shape
[
-
1
])
grad_weight2
,
grad_bias2
=
fused_dense_cuda
.
linear_bias_wgrad
(
output1
,
grad_output
)
grad_act_input
=
triton_dgrad_act
(
grad_output
,
weight2
,
activation
=
"squared_relu"
,
act_input
=
act_input
)
grad_input
,
grad_weight1
,
grad_bias1
=
fused_dense_cuda
.
linear_bias_backward
(
x
.
reshape
(
batch_dim
,
n
),
weight1
,
grad_act_input
)
return
grad_input
.
reshape_as
(
x
),
grad_weight1
,
grad_bias1
,
grad_weight2
,
grad_bias2
,
None
fused_dense_sqrelu_dense_function
=
FusedDenseSqreluDenseFunc
.
apply
class
FusedDenseSqreluDense
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
bias1
=
True
,
bias2
=
True
,
checkpoint_lvl
=
0
,
device
=
None
,
dtype
=
None
,
):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
"""
assert
checkpoint_lvl
in
[
0
,
1
,
2
]
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
*
4
assert
bias1
==
True
,
"DenseSqreluDense module without bias is currently not supported"
assert
bias2
==
True
,
"DenseSqreluDense module without bias is currently not supported"
self
.
checkpoint_lvl
=
checkpoint_lvl
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
,
bias
=
bias1
,
**
factory_kwargs
)
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
,
bias
=
bias2
,
**
factory_kwargs
)
def
forward
(
self
,
x
):
assert
x
.
is_cuda
return
fused_dense_sqrelu_dense_function
(
x
,
self
.
fc1
.
weight
,
self
.
fc1
.
bias
,
self
.
fc2
.
weight
,
self
.
fc2
.
bias
,
self
.
checkpoint_lvl
)
vllm_flash_attn/ops/triton/rotary.py
deleted
100644 → 0
View file @
6ac8e63a
# Copyright (c) 2023, Tri Dao.
from
typing
import
Optional
,
Union
import
torch
import
triton
import
triton.language
as
tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
# triton.Config({"BLOCK_M": 4}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
# )
@
triton
.
jit
def
rotary_kernel
(
OUT
,
# Pointers to matrices
X
,
COS
,
SIN
,
CU_SEQLENS
,
SEQLEN_OFFSETS
,
# this could be int or a pointer
# Matrix dimensions
seqlen
,
nheads
,
rotary_dim
,
seqlen_ro
,
CACHE_KEY_SEQLEN
,
# strides
stride_out_batch
,
stride_out_seqlen
,
stride_out_nheads
,
stride_out_headdim
,
stride_x_batch
,
stride_x_seqlen
,
stride_x_nheads
,
stride_x_headdim
,
# Meta-parameters
BLOCK_K
:
tl
.
constexpr
,
IS_SEQLEN_OFFSETS_TENSOR
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
INTERLEAVED
:
tl
.
constexpr
,
CONJUGATE
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_batch
=
tl
.
program_id
(
axis
=
1
)
pid_head
=
tl
.
program_id
(
axis
=
2
)
rotary_dim_half
=
rotary_dim
//
2
if
not
IS_VARLEN
:
X
=
X
+
pid_batch
*
stride_x_batch
+
pid_head
*
stride_x_nheads
OUT
=
OUT
+
pid_batch
*
stride_out_batch
+
pid_head
*
stride_out_nheads
else
:
start_idx
=
tl
.
load
(
CU_SEQLENS
+
pid_batch
)
seqlen
=
tl
.
load
(
CU_SEQLENS
+
pid_batch
+
1
)
-
start_idx
X
=
X
+
start_idx
*
stride_x_seqlen
+
pid_head
*
stride_x_nheads
OUT
=
OUT
+
start_idx
*
stride_out_seqlen
+
pid_head
*
stride_out_nheads
if
pid_m
*
BLOCK_M
>=
seqlen
:
return
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
if
not
IS_SEQLEN_OFFSETS_TENSOR
:
rm_cs
=
rm
+
SEQLEN_OFFSETS
else
:
rm_cs
=
rm
+
tl
.
load
(
SEQLEN_OFFSETS
+
pid_batch
)
rk
=
tl
.
arange
(
0
,
BLOCK_K
)
rk_half
=
tl
.
arange
(
0
,
BLOCK_K
//
2
)
if
not
INTERLEAVED
:
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
X
=
X
+
(
rm
[:,
None
]
*
stride_x_seqlen
+
rk_half
[
None
,
:]
*
stride_x_headdim
)
COS
=
COS
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_half
[
None
,
:])
SIN
=
SIN
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_half
[
None
,
:])
cos
=
tl
.
load
(
COS
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
1.0
).
to
(
tl
.
float32
)
sin
=
tl
.
load
(
SIN
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
).
to
(
tl
.
float32
)
x0
=
tl
.
load
(
X
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
).
to
(
tl
.
float32
)
x1
=
tl
.
load
(
X
+
rotary_dim_half
*
stride_x_headdim
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
,
).
to
(
tl
.
float32
)
if
CONJUGATE
:
sin
=
-
sin
o0
=
x0
*
cos
-
x1
*
sin
o1
=
x0
*
sin
+
x1
*
cos
# write back result
OUT
=
OUT
+
(
rm
[:,
None
]
*
stride_out_seqlen
+
rk_half
[
None
,
:]
*
stride_out_headdim
)
tl
.
store
(
OUT
,
o0
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
))
tl
.
store
(
OUT
+
rotary_dim_half
*
stride_out_headdim
,
o1
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
)
else
:
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick put the right outputs for the even
# and for the odd indices.
rk_swap
=
rk
+
((
rk
+
1
)
%
2
)
*
2
-
1
# 1, 0, 3, 2, 5, 4, ...
rk_repeat
=
tl
.
arange
(
0
,
BLOCK_K
)
//
2
X0
=
X
+
(
rm
[:,
None
]
*
stride_x_seqlen
+
rk
[
None
,
:]
*
stride_x_headdim
)
X1
=
X
+
(
rm
[:,
None
]
*
stride_x_seqlen
+
rk_swap
[
None
,
:]
*
stride_x_headdim
)
COS
=
COS
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_repeat
[
None
,
:])
SIN
=
SIN
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_repeat
[
None
,
:])
cos
=
tl
.
load
(
COS
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_repeat
[
None
,
:]
<
rotary_dim_half
),
other
=
1.0
,
).
to
(
tl
.
float32
)
sin
=
tl
.
load
(
SIN
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_repeat
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
,
).
to
(
tl
.
float32
)
x0
=
tl
.
load
(
X0
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim
),
other
=
0.0
).
to
(
tl
.
float32
)
x1
=
tl
.
load
(
X1
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_swap
[
None
,
:]
<
rotary_dim
),
other
=
0.0
).
to
(
tl
.
float32
)
if
CONJUGATE
:
sin
=
-
sin
x0_cos
=
x0
*
cos
x1_sin
=
x1
*
sin
out
=
tl
.
where
(
rk
[
None
,
:]
%
2
==
0
,
x0_cos
-
x1_sin
,
x0_cos
+
x1_sin
)
OUT
=
OUT
+
(
rm
[:,
None
]
*
stride_out_seqlen
+
rk
[
None
,
:]
*
stride_out_headdim
)
tl
.
store
(
OUT
,
out
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim
))
def
apply_rotary
(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
seqlen_offsets
:
Union
[
int
,
torch
.
Tensor
]
=
0
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
max_seqlen
:
Optional
[
int
]
=
None
,
interleaved
=
False
,
inplace
=
False
,
conjugate
=
False
,
)
->
torch
.
Tensor
:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
is_varlen
=
cu_seqlens
is
not
None
if
not
is_varlen
:
batch
,
seqlen
,
nheads
,
headdim
=
x
.
shape
else
:
assert
max_seqlen
is
not
None
,
"If cu_seqlens is passed in, then max_seqlen must be passed"
total_seqlen
,
nheads
,
headdim
=
x
.
shape
batch_p_1
=
cu_seqlens
.
shape
[
0
]
batch
=
batch_p_1
-
1
seqlen
=
max_seqlen
seqlen_ro
,
rotary_dim
=
cos
.
shape
assert
sin
.
shape
==
cos
.
shape
rotary_dim
*=
2
assert
rotary_dim
<=
headdim
,
"rotary_dim must be <= headdim"
assert
headdim
<=
256
,
"Only support headdim <= 256"
assert
seqlen_ro
>=
seqlen
,
"seqlen_ro must be >= seqlen"
assert
(
cos
.
dtype
==
sin
.
dtype
),
f
"cos and sin must have the same dtype, got
{
cos
.
dtype
}
and
{
sin
.
dtype
}
"
assert
(
x
.
dtype
==
cos
.
dtype
),
f
"Input and cos/sin must have the same dtype, got
{
x
.
dtype
}
and
{
cos
.
dtype
}
"
cos
,
sin
=
cos
.
contiguous
(),
sin
.
contiguous
()
if
isinstance
(
seqlen_offsets
,
torch
.
Tensor
):
assert
seqlen_offsets
.
shape
==
(
batch
,)
assert
seqlen_offsets
.
dtype
in
[
torch
.
int32
,
torch
.
int64
]
seqlen_offsets
=
seqlen_offsets
.
contiguous
()
else
:
assert
seqlen_offsets
+
seqlen
<=
seqlen_ro
output
=
torch
.
empty_like
(
x
)
if
not
inplace
else
x
if
rotary_dim
<
headdim
and
not
inplace
:
output
[...,
rotary_dim
:].
copy_
(
x
[...,
rotary_dim
:])
BLOCK_K
=
(
32
if
rotary_dim
<=
32
else
(
64
if
rotary_dim
<=
64
else
(
128
if
rotary_dim
<=
128
else
256
))
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
seqlen
,
META
[
"BLOCK_M"
]),
batch
,
nheads
)
# noqa
BLOCK_M
=
4
if
interleaved
else
(
8
if
rotary_dim
<=
64
else
4
)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with
torch
.
cuda
.
device
(
x
.
device
.
index
):
rotary_kernel
[
grid
](
output
,
# data ptrs
x
,
cos
,
sin
,
cu_seqlens
,
seqlen_offsets
,
seqlen
,
# shapes
nheads
,
rotary_dim
,
seqlen_ro
,
seqlen
//
128
,
# key for triton cache (limit number of compilations)
output
.
stride
(
0
)
if
not
is_varlen
else
0
,
# batch_strides if not varlen else 0
output
.
stride
(
-
3
),
# seqlen_stride or total_seqlen_stride
output
.
stride
(
-
2
),
# nheads_stride
output
.
stride
(
-
1
),
# headdim_stride
x
.
stride
(
0
)
if
not
is_varlen
else
0
,
# batch_strides if not varlen else 0
x
.
stride
(
-
3
),
# seqlen stride or total_seqlen_stride
x
.
stride
(
-
2
),
# nheads stride
x
.
stride
(
-
1
),
# headdim stride
BLOCK_K
,
isinstance
(
seqlen_offsets
,
torch
.
Tensor
),
is_varlen
,
interleaved
,
conjugate
,
BLOCK_M
,
)
return
output
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment