flash-attention · Commits

Commit 2e33fc8e, authored Nov 13, 2022 by Tri Dao

Add GPT and ViT models

parent d4b320b3

Showing 8 changed files with 1209 additions and 5 deletions (+1209, -5)
README.md                               +3    -3
csrc/fused_dense_lib/README.md          +1    -1
csrc/layer_norm/README.md               +1    -1
flash_attn/models/gpt.py                +174  -0
flash_attn/models/vit.py                +249  -0
flash_attn/ops/triton/k_activations.py  +162  -0
flash_attn/ops/triton/linear.py         +479  -0
flash_attn/ops/triton/mlp.py            +140  -0
README.md

@@ -52,7 +52,7 @@ Our tentative roadmap:
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
-9. [Aug 2022] Fuse rotary embedding.
+9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).

 ## Speedup and Memory Savings
...
@@ -154,10 +154,10 @@ and for his thoughtful answers to our questions about CUDA.
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
 ```
-@article{dao2022flashattention,
+@inproceedings{dao2022flashattention,
   title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  journal={arXiv preprint arXiv:2205.14135},
+  booktitle={Advances in Neural Information Processing Systems},
   year={2022}
 }
 ```
csrc/fused_dense_lib/README.md

-This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
+This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
 (forward and backward), adapted from Apex's
 [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
 We make it work for bfloat16.
...
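For orientation, the unfused PyTorch computation that this extension fuses is roughly the following. This is a minimal sketch, not the extension's API; the function name `reference_dense_gelu` is made up here for illustration.

```python
import torch
import torch.nn.functional as F

def reference_dense_gelu(x, weight, bias):
    # Unfused reference: one GEMM plus bias, then GELU.
    # The CUDA extension fuses these steps (and the matching backward pass).
    return F.gelu(F.linear(x, weight, bias))

# e.g. x: (batch, seqlen, in_features), weight: (out_features, in_features), bias: (out_features,)
```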
csrc/layer_norm/README.md

-This CUDA extensions implements fused dropout + residual + LayerNorm, based on
+This CUDA extension implements fused dropout + residual + LayerNorm, based on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
...
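As a point of reference, the unfused pre-norm residual path this extension replaces looks roughly like the sketch below (names are illustrative, not the extension's API). Returning the pre-norm sum alongside the normalized output mirrors how the models in this commit thread the residual through the blocks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def reference_dropout_add_layer_norm(x, residual, norm: nn.LayerNorm, p: float, training: bool):
    # Unfused reference: dropout on the branch output, add the residual,
    # then LayerNorm the sum. The fused kernel does all three in one pass
    # and also returns the pre-norm sum so it can feed the next residual.
    added = F.dropout(x, p=p, training=training) + residual
    return norm(added), added
```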
flash_attn/models/gpt.py (new file, 0 → 100644)

```python
# Copyright (c) 2022, Tri Dao.

import math
from functools import partial
from collections import namedtuple
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None


def create_mixer_cls(config, layer_idx=None):
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
                        rotary_emb_dim=rotary_emb_dim, fused_bias_fc=fused_bias_fc,
                        use_flash_attn=use_flash_attn)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None):
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate='tanh'))
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_dense_gelu_dense:
            mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        elif fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None):
    mixer_cls = create_mixer_cls(config, layer_idx)
    mlp_cls = create_mlp_cls(config, layer_idx)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=True, resid_dropout=config.resid_pdrop,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPT2Model(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
        if self.pad_vocab_size_multiple_8:
            if config.vocab_size % 8 != 0:
                config.vocab_size += 8 - (config.vocab_size % 8)
        self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
                                         config.max_position_embeddings)
        self.emb_drop = nn.Dropout(config.embd_pdrop)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
        # is the final layer norm.
        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                     for i in range(config.num_hidden_layers)])

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.embeddings(input_ids, position_ids=position_ids)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        if not self.fused_dropout_add_ln:
            residual = self.emb_drop(hidden_states).float()
            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
        else:
            hidden_states, residual = dropout_add_layer_norm(
                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
                self.emb_drop.p if self.training else 0.0, self.ln_0.eps,
                prenorm=True, residual_in_fp32=True
            )
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        return hidden_states


class GPT2LMHeadModel(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.transformer(input_ids, position_ids=position_ids)
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)
```
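A minimal usage sketch for the model above, assuming the optional fused kernels are installed. The config flags are the ones read via `getattr` in `create_mixer_cls` / `create_mlp_cls` / `create_block`; the specific sizes and flag values here are illustrative only.

```python
import torch
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPT2LMHeadModel

config = GPT2Config(n_embd=768, n_layer=12, n_head=12)
# Optional speed knobs read by the factory functions above:
config.use_flash_attn = True
config.fused_bias_fc = False
config.fused_dense_gelu_dense = False
config.fused_dropout_add_ln = False

model = GPT2LMHeadModel(config).cuda().half()
input_ids = torch.randint(0, config.vocab_size, (2, 512), device='cuda')
logits = model(input_ids).logits  # (batch, seqlen, vocab_size)
```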
flash_attn/models/vit.py (new file, 0 → 100644)

```python
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

import math
from functools import partial
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from einops import rearrange

from timm.models.helpers import named_apply

from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block


def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
                     cross_attn=False):
    mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
                        dropout=attn_drop, fused_bias_fc=fused_bias_fc,
                        use_flash_attn=use_flash_attn)
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_dense_gelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
    return mlp_cls


def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
                 drop_path, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
                 fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
                 last_layer_subset=False):
    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn,
                                 fused_bias_fc,
                                 cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, prenorm=True,
                  resid_dropout=drop_rate, drop_path=drop_path,
                  fused_dropout_add_ln=fused_dropout_add_ln)
    return block


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            init_values=None,
            class_token=True,
            no_embed_class=False,
            pre_norm=False,
            fc_norm=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            weight_init='',
            embed_layer=PatchEmbed,
            norm_layer=None,
            act_layer=None,
            use_flash_attn=False,
            fused_bias_fc=False,
            fused_dense_gelu_dense=False,
            fused_dropout_add_ln=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == 'token', 'Only support pooling with CLS token'
        assert class_token
        assert init_values is None, 'LayerScale is not supported yet'
        assert weight_init == ''
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
                                    else {})
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        # self.norm_0 is the first layer norm in the model, while self.norm
        # (in the pretrained weight) is the final layer norm.
        self.norm_0 = norm_layer(embed_dim)

        self.blocks = nn.ModuleList([create_block(
            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
            drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
            use_flash_attn=use_flash_attn, fused_bias_fc=fused_bias_fc,
            fused_dense_gelu_dense=fused_dense_gelu_dense,
            fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
            last_layer_subset=(global_pool == 'token')
        ) for i in range(depth)])

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode == ''
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x, all_tokens=True):
        """
        If all_tokens==False and self.global_pool == 'token', we only return the features for the
        cls token.
        """
        x = self.patch_embed(x)
        # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
        residual = self._pos_embed(x).float()
        hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
        if self.global_pool != 'token' or all_tokens:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
            for block in self.blocks[:-1]:
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
            hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st,
                                               mixer_kwargs={'x_kv': hidden_states})
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x, all_tokens=False)
        x = self.forward_head(x)
        return x


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model
```
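A usage sketch for the ViT entry point above (illustrative values; assumes timm, einops, and the flash-attn attention/MLP modules it imports are available on a CUDA GPU). The `use_flash_attn` and `fused_*` flags are simply forwarded into `VisionTransformer(**kwargs)`.

```python
import torch
from flash_attn.models.vit import vit_base_patch16_224

model = vit_base_patch16_224(use_flash_attn=True, fused_bias_fc=False,
                             fused_dense_gelu_dense=False).cuda().half()
images = torch.randn(4, 3, 224, 224, device='cuda', dtype=torch.float16)
logits = model(images)  # (4, 1000): class-token features passed through the classifier head
```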
flash_attn/ops/triton/k_activations.py (new file, 0 → 100644)

```python
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations
    # in that it does not require the input to retrospectively compute its gradient
    # here the input is the downstream gradient, and we return the upstream gradient directly
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
```
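The kernels above follow the usual closed forms for these activations. A quick PyTorch cross-check of two of them might look like the sketch below; this is a test scribble for orientation, not part of the file, and the tolerance is a reasonable guess for float32.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1024)

# squared ReLU: relu(x) ** 2
assert torch.allclose(F.relu(x) ** 2, F.relu(x) * F.relu(x))

# the tanh-approximate GELU should match PyTorch's approximate='tanh' variant closely
ref = F.gelu(x, approximate='tanh')
approx = 0.5 * x * (1.0 + torch.tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x)))
assert torch.allclose(ref, approx, atol=1e-6)
```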
flash_attn/ops/triton/linear.py (new file, 0 → 100644)

```python
# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch

import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (gelu, gelu_grad, gelu_approx, gelu_approx_grad,
                                                 squared_relu, squared_relu_grad)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
                            num_stages=num_stages,
                            num_warps=num_warps,
                        )
                    )
                    # split_k not used
                    # for split_k in [2, 4, 8, 16]:
                    #     configs.append(triton.Config(
                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    C,  # Pointers to matrices
    ACT_INPUT, A, B, bias,
    # Matrix dimensions
    M, N, K,
    CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am, stride_ak,
    stride_bn, stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr, B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr, SAVE_ACT_INPUT: tl.constexpr, ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing Out = activation(A x W + C)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Bias has shape (N,)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + C intermediate for backward computations
    This kernel will consolidate over K
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Putting bias after the matmul (instead of before) is faster, idk why
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # optional: save the activation inputs
    if SAVE_ACT_INPUT:
        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        tl.store(act_in_ptrs, acc)

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc)


def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = 'id',
    save_act_input: bool = False,
) -> torch.Tensor:
    """
    Compute e = activation(x @ weight.T + bias).
    This wrapper kicks the `kernel_fwd` Triton kernel
    :param x: input tensor
    :param weight: weight matrix
    :param bias: an optional bias tensor
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    # if torch.is_autocast_enabled():
    #     dtype = torch.get_autocast_gpu_dtype()
    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']

    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        bias if bias is not None else x,  # auto skip bias if not present
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        BIAS=bias is not None,  # optional fused bias
        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
        ACTIVATION=activation,  # optional fused activation
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (output.reshape(*batch_shape, output.shape[-1]),
                act_input.reshape(*batch_shape, act_input.shape[-1]))


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_bwd(
    C,  # Pointers to matrices
    ACT_INPUT, A, B,
    # Matrix dimensions
    M, N, K,
    CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am, stride_ak,
    stride_bk, stride_bn,
    # Meta-parameters
    BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing Out = activation(A x W + C)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + C intermediate for backward computations
    This kernel will consolidate over K
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION != 'id':
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
    if ACTIVATION == "gelu":
        acc *= gelu_grad(act_input)
    elif ACTIVATION == "gelu_approx":
        acc *= gelu_approx_grad(act_input)
    elif ACTIVATION == "squared_relu":
        acc *= squared_relu_grad(act_input)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = 'id',
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute e = activation(grad_output @ weight + bias).
    This wrapper kicks the `kernel_fwd` Triton kernel
    :param grad_output: input tensor
    :param weight: weight matrix
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']

    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != 'id':
        assert act_input is not None, f'act_input is required for activation {activation}'

    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_bwd[grid](
        grad_input,
        act_input,
        grad_output_reshaped,
        weight,  # data ptrs
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=grad_input.stride(0),  # strides
        # stride_cn=grad_input.stride(1),
        stride_am=grad_output_reshaped.stride(0),
        stride_ak=grad_output_reshaped.stride(1),
        stride_bk=weight.stride(0),
        stride_bn=weight.stride(1),
        ACTIVATION=activation,  # optional fused activation
        GROUP_M=8,  # speed optimization: group the programs
    )

    return grad_input.reshape(*batch_shape, grad_input.shape[-1])
```
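A usage sketch for `triton_linear_act`, compared against the unfused PyTorch equivalent. It assumes a CUDA GPU with triton installed; the shapes are illustrative, and some fp16-level numerical difference between the two paths is expected.

```python
import torch
import torch.nn.functional as F
from flash_attn.ops.triton.linear import triton_linear_act

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
weight = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)
bias = torch.randn(4096, device='cuda', dtype=torch.float16)

# Fused GEMM + bias + tanh-approximate GELU in one kernel launch
out = triton_linear_act(x, weight, bias, activation='gelu_approx')

# Unfused reference
ref = F.gelu(F.linear(x, weight, bias), approximate='tanh')
print((out - ref).abs().max())
```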
flash_attn/ops/triton/mlp.py (new file, 0 → 100644)

```python
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

import fused_dense_lib as fused_dense_cuda

from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


class FusedDenseSqreluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute act_input and gelu_out in the bwd
        """
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
                                                 for a in [x, weight1, bias1, weight2, bias2]]
        is_bf16 = x.dtype == torch.bfloat16
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if is_bf16:
            act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
            output1 = sqrelu_fwd(act_input)
        else:
            save_act_input = checkpoint_lvl != 2
            result = triton_linear_act(x.reshape(batch_dim, n), weight1, bias1,
                                       activation='squared_relu',
                                       save_act_input=save_act_input)
            if save_act_input:
                output1, act_input = result
            else:
                output1 = result
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, bias1, weight2, act_input)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, bias1, weight2)
        return output2.reshape(*batch_shape, output2.shape[-1])

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        is_bf16 = x.dtype == torch.bfloat16
        if checkpoint_lvl == 0:
            act_input, output1 = rest
        elif checkpoint_lvl == 1:
            act_input, = rest
            output1 = sqrelu_fwd(act_input)
        elif checkpoint_lvl == 2:
            if is_bf16:
                act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
                output1 = sqrelu_fwd(act_input)
            else:
                output1, act_input = triton_linear_act(x.reshape(batch_dim, n), weight1, bias1,
                                                       activation='squared_relu',
                                                       save_act_input=True)

        if is_bf16:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_output1 = grad_output @ weight2
            grad_act_input = sqrelu_bwd(grad_output1, act_input)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input)
        else:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu',
                                              act_input=act_input)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input)
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None


fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply


class FusedDenseSqreluDense(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
                 checkpoint_lvl=0, device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert bias == True, "DenseSqreluDense module without bias is currently not supported"
        self.checkpoint_lvl = checkpoint_lvl
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        assert x.is_cuda
        return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias,
                                                 self.fc2.weight, self.fc2.bias,
                                                 self.checkpoint_lvl)
```
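A usage sketch for the module above. It assumes the fused_dense_lib CUDA extension and triton are both built; the sizes are illustrative.

```python
import torch
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense

mlp = FusedDenseSqreluDense(in_features=1024, hidden_features=4096,
                            checkpoint_lvl=1, device='cuda', dtype=torch.float16)
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
y = mlp(x)            # fc1 -> squared ReLU -> fc2, via the fused Triton/CUDA path
y.sum().backward()    # backward recomputes intermediates according to checkpoint_lvl
```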