gaoqiong / flash-attention / Commits / f1a73d07

Commit f1a73d07, authored Aug 18, 2023 by Tri Dao

Run isort and black on python files

parent cbb4cf5f
Changes: 34
Showing 14 changed files with 1403 additions and 459 deletions (+1403 −459)
flash_attn/modules/embedding.py         +86   −53
flash_attn/modules/mha.py               +6    −2
flash_attn/modules/mlp.py               +97   −33
flash_attn/ops/activations.py           +11   −6
flash_attn/ops/fused_dense.py           +4    −3
flash_attn/ops/layer_norm.py            +535  −110
flash_attn/ops/rms_norm.py              +113  −28
flash_attn/ops/triton/k_activations.py  +4    −4
flash_attn/ops/triton/linear.py         +173  −56
flash_attn/ops/triton/mlp.py            +41   −23
flash_attn/utils/benchmark.py           +155  −58
flash_attn/utils/distributed.py         +18   −13
flash_attn/utils/generation.py          +141  −60
flash_attn/utils/pretrained.py          +19   −10
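Per the commit message, these files were run through isort and then black. As a point of reference, a minimal sketch of that kind of mechanical reformat using the two tools' Python APIs follows; the line length of 100 is an assumption about this repository's configuration, not something recorded in the diff.

# Sketch only: applies isort followed by black to one file, the way a
# formatting-only commit like this one is typically produced.
# Assumption: a line length of 100; the repo's real configuration may differ.
from pathlib import Path

import black
import isort


def reformat_file(path: str, line_length: int = 100) -> None:
    source = Path(path).read_text()
    sorted_source = isort.code(source, line_length=line_length)  # sort and group imports
    formatted = black.format_str(sorted_source, mode=black.Mode(line_length=line_length))
    Path(path).write_text(formatted)


if __name__ == "__main__":
    reformat_file("flash_attn/modules/embedding.py")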
flash_attn/modules/embedding.py  +86 −53

...
@@ -2,42 +2,52 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
 from einops import rearrange
+from torch import Tensor

-from flash_attn.utils.distributed import reduce_scatter, all_reduce
+from flash_attn.utils.distributed import all_reduce, reduce_scatter


 class GPT2Embeddings(nn.Module):
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
-                 word_embed_proj_dim=None, device=None, dtype=None):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        padding_idx=None,
+        word_embed_proj_dim=None,
+        device=None,
+        dtype=None,
+    ):
         """
-            If max_position_embeddings <= 0, there's no position embeddings
-            If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
-                the project up to embed_dim
+        If max_position_embeddings <= 0, there's no position embeddings
+        If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
+            the project up to embed_dim
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         if word_embed_proj_dim is None:
-            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
-                                                **factory_kwargs)
+            self.word_embeddings = nn.Embedding(
+                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+            )
             self.project_in = None
         else:
-            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
-                                                padding_idx=padding_idx, **factory_kwargs)
-            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
-                                        **factory_kwargs)
+            self.word_embeddings = nn.Embedding(
+                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
+            )
+            self.project_in = nn.Linear(
+                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
+            )
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
-                                                    **factory_kwargs)
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim, **factory_kwargs
+            )

     def forward(self, input_ids, position_ids=None):
         """
-            input_ids: (batch, seqlen)
-            position_ids: (batch, seqlen)
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
...
@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):
 class BertEmbeddings(nn.Module):
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
-                 padding_idx=None, device=None, dtype=None):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        type_vocab_size,
+        padding_idx=None,
+        device=None,
+        dtype=None,
+    ):
         """
-            If max_position_embeddings <= 0, there's no position embeddings
-            If type_vocab_size <= 0, there's no token type embeddings
+        If max_position_embeddings <= 0, there's no position embeddings
+        If type_vocab_size <= 0, there's no token type embeddings
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
-                                            **factory_kwargs)
+        self.word_embeddings = nn.Embedding(
+            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+        )
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
-                                                    **factory_kwargs)
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim, **factory_kwargs
+            )
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
+            self.token_type_embeddings = nn.Embedding(
+                type_vocab_size, embed_dim, **factory_kwargs
+            )

     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         """
-            input_ids: (batch, seqlen)
-            position_ids: (batch, seqlen)
-            token_type_ids: (batch, seqlen)
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
+        token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
...
@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):
 class VocabParallelEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
         self.process_group = process_group
         if process_group is not None:
             world_size = torch.distributed.get_world_size(process_group)
             if num_embeddings % world_size != 0:
-                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
-                                 f'world_size ({world_size})')
+                raise ValueError(
+                    f"num_embeddings ({num_embeddings}) must be divisible by "
+                    f"world_size ({world_size})"
+                )
             if world_size > 1 and padding_idx is not None:
-                raise RuntimeError('ParallelEmbedding does not support padding_idx')
+                raise RuntimeError("ParallelEmbedding does not support padding_idx")
         else:
             world_size = 1
         super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
...
@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):
 class ColumnParallelEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
         self.process_group = process_group
         if process_group is not None:
             world_size = torch.distributed.get_world_size(process_group)
             if embedding_dim % world_size != 0:
-                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
-                                 f'world_size ({world_size})')
+                raise ValueError(
+                    f"embedding_dim ({embedding_dim}) must be divisible by "
+                    f"world_size ({world_size})"
+                )
         else:
             world_size = 1
         super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


 class ParallelGPT2Embeddings(nn.Module):
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
-                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        process_group,
+        padding_idx=None,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ):
         """
-            If max_position_embeddings <= 0, there's no position embeddings
+        If max_position_embeddings <= 0, there's no position embeddings
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.word_embeddings = VocabParallelEmbedding(
-            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
-            **factory_kwargs
+            vocab_size,
+            embed_dim,
+            padding_idx=padding_idx,
+            process_group=process_group,
+            **factory_kwargs,
         )
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
...
@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):
     def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
         """
-            input_ids: (batch, seqlen)
-            position_ids: (batch, seqlen)
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
         world_size = torch.distributed.get_world_size(self.process_group)
...
@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
             else:
                 partition_dim = self.position_embeddings.embedding_dim
                 rank = torch.distributed.get_rank(self.process_group)
-                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
+                embeddings[
+                    ..., rank * partition_dim : (rank + 1) * partition_dim
+                ] += position_embeddings
         if combine_batch_seqlen_dim:
-            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+            embeddings = rearrange(embeddings, "b s d -> (b s) d")
         reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
         return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
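The ParallelGPT2Embeddings.forward hunk above adds each rank's shard of the position embeddings into its own slice of the hidden dimension. The following standalone sketch replays that column partitioning with the ranks simulated in a loop; the sizes are illustrative only.

# Sketch of the column partitioning used in ParallelGPT2Embeddings.forward above.
# Each "rank" holds a (seqlen, embed_dim // world_size) shard of the position-embedding
# table and adds it into its own slice of the hidden dimension. Shapes are made up;
# the real module does this through torch.distributed.
import torch

batch, seqlen, embed_dim, world_size = 2, 8, 16, 4
partition_dim = embed_dim // world_size  # embedding_dim of each ColumnParallelEmbedding shard

embeddings = torch.zeros(batch, seqlen, embed_dim)
position_shards = [torch.randn(seqlen, partition_dim) for _ in range(world_size)]

for rank in range(world_size):
    # The slice update from the diff, with rank taken from the loop instead of
    # torch.distributed.get_rank(process_group).
    embeddings[..., rank * partition_dim : (rank + 1) * partition_dim] += position_shards[rank]

assert embeddings.shape == (batch, seqlen, embed_dim)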
flash_attn/modules/mha.py  +6 −2

...
@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
-        self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
+        self.num_heads_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
+        self.num_heads_kv_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
...
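The hunk above reflows two calls to get_dim_for_local_rank, which decides how many attention heads a given tensor-parallel rank owns. The helper's implementation is not part of this diff; the sketch below is a hypothetical even-split stand-in, only to make the per-rank bookkeeping concrete.

# Hypothetical stand-in for get_dim_for_local_rank (not the real helper from
# flash_attn.utils.distributed): assumes the dimension divides evenly across ranks.
def dim_for_local_rank_even(dim: int, world_size: int, local_rank: int) -> int:
    assert dim % world_size == 0, "this sketch only covers the evenly divisible case"
    return dim // world_size


# With 32 query heads and 8 KV heads on 4 tensor-parallel ranks, each rank would hold
# 8 query heads and 2 KV heads.
num_heads, num_heads_kv, world_size = 32, 8, 4
for local_rank in range(world_size):
    heads_per_rank = dim_for_local_rank_even(num_heads, world_size, local_rank)
    heads_kv_per_rank = dim_for_local_rank_even(num_heads_kv, world_size, local_rank)
    assert (heads_per_rank, heads_kv_per_rank) == (8, 2)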
flash_attn/modules/mlp.py  +97 −33

...
@@ -17,10 +17,19 @@ except ImportError:
 class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
-                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        activation=F.gelu,
+        bias1=True,
+        bias2=True,
+        return_residual=False,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         out_features = out_features if out_features is not None else in_features
         hidden_features = hidden_features if hidden_features is not None else in_features * 4
...
@@ -37,21 +46,42 @@ class Mlp(nn.Module):
 class ParallelMLP(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
-                 process_group: ProcessGroup = None, sequence_parallel=True, bias1=True,
-                 bias2=True, device=None, dtype=None):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        activation=F.gelu,
+        process_group: ProcessGroup = None,
+        sequence_parallel=True,
+        bias1=True,
+        bias2=True,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         assert ColumnParallelLinear is not None, "Need to install fused_dense"
         assert RowParallelLinear is not None, "Need to install fused_dense"
         out_features = out_features if out_features is not None else in_features
         hidden_features = hidden_features if hidden_features is not None else in_features * 4
-        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
-                                        sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.fc1 = ColumnParallelLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias=bias1,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
         self.activation = activation
-        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
-                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.fc2 = RowParallelLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias2,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )

     def forward(self, x):
         y = self.fc1(x)
...
@@ -61,15 +91,25 @@ class ParallelMLP(nn.Module):
 class GatedMlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
-                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
-                 device=None, dtype=None):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        activation=F.sigmoid,
+        bias1=True,
+        bias2=True,
+        multiple_of=256,
+        return_residual=False,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         out_features = out_features if out_features is not None else in_features
-        hidden_features = (hidden_features if hidden_features is not None
-                           else int(8 * in_features / 3))
+        hidden_features = (
+            hidden_features if hidden_features is not None else int(8 * in_features / 3)
+        )
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
...
@@ -88,24 +128,48 @@ class GatedMlp(nn.Module):
 class ParallelGatedMlp(nn.Module):
-    """ Parallel GatedMlp """
+    """Parallel GatedMlp"""

-    def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
-                 activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
-                 sequence_parallel=True, device=None, dtype=None):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+    def __init__(
+        self,
+        in_features,
+        process_group,
+        hidden_features=None,
+        out_features=None,
+        activation=F.sigmoid,
+        bias1=True,
+        bias2=True,
+        multiple_of=256,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         out_features = out_features if out_features is not None else in_features
-        hidden_features = (hidden_features if hidden_features is not None
-                           else int(8 * in_features / 3))
+        hidden_features = (
+            hidden_features if hidden_features is not None else int(8 * in_features / 3)
+        )
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         if ColumnParallelLinear is None or RowParallelLinear is None:
-            raise ImportError('fused_dense is not installed')
-        self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group,
-                                        bias=bias1, sequence_parallel=sequence_parallel,
-                                        **factory_kwargs)
+            raise ImportError("fused_dense is not installed")
+        self.fc1 = ColumnParallelLinear(
+            in_features,
+            2 * hidden_features,
+            process_group,
+            bias=bias1,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
         self.activation = activation
-        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
-                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.fc2 = RowParallelLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias2,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )

     def forward(self, x):
         y = self.fc1(x)
...
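GatedMlp and ParallelGatedMlp above default hidden_features to int(8 * in_features / 3) and then round it up to a multiple of multiple_of. A worked example of that arithmetic follows, with in_features=4096 and the default multiple_of=256 chosen purely for illustration.

# Worked example of the GatedMlp hidden-size default from the diff above.
# in_features=4096 and multiple_of=256 are illustrative values, not taken from the commit.
in_features = 4096
multiple_of = 256

hidden_features = int(8 * in_features / 3)  # 10922: the 8/3 expansion used for gated MLPs
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of  # round up

assert hidden_features == 11008  # 43 * 256, the next multiple of 256 above 10922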
flash_attn/ops/activations.py  +11 −6

...
@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-
 # 1/sqrt(2*pi)-> 0.3989423
 # 1/sqrt(2) -> 0.70710678
 # sqrt(2/pi) -> 0.79788456
...
@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
     x = bias + y
     return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
 @torch.jit.script
 def bias_gelu_back(g, y, bias):
-    """Assume that y has shape (B, D) and bias has shape (D)
-    """
+    """Assume that y has shape (B, D) and bias has shape (D)"""
     x = bias + y
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
     grad_y = ff * g
     return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
...
@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
 def gelu_fwd(x):
     return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)

+
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
...
@@ -63,7 +65,9 @@ def gelu_fwd(x):
 def gelu_bwd(g, x):
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
     return (ff * g).to(dtype=x.dtype)
...
@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        input, = ctx.saved_tensors
+        (input,) = ctx.saved_tensors
         tmp = gelu_bwd(grad_output, input)
         return tmp


 fast_gelu_impl = FastGeLUFunction.apply
...
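bias_gelu, gelu_fwd, and gelu_bwd above hard-code the tanh approximation of GELU, with 0.79788456 ≈ sqrt(2/pi) and 0.1070322243 ≈ sqrt(2/pi) * 3 * 0.044715. The following sketch is a numerical sanity check (not part of the commit) of the forward formula against PyTorch's tanh-approximate GELU and of the ff gradient factor against autograd.

# Numerical sanity check of the tanh-approximate GELU used in flash_attn/ops/activations.py.
# This is a test sketch, not part of the commit.
import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)

# Forward: matches F.gelu(..., approximate="tanh") up to the truncated constants.
y = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
assert torch.allclose(y, F.gelu(x, approximate="tanh"), atol=1e-6)

# Gradient factor ff from gelu_bwd: compare against autograd of the forward expression.
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
    1 + tanh_out
)
(autograd_grad,) = torch.autograd.grad(y.sum(), x)
assert torch.allclose(ff, autograd_grad, atol=1e-6)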
flash_attn/ops/fused_dense.py  +4 −3

...
@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.distributed import ProcessGroup
+
 from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
 from flash_attn.utils.distributed import (
     all_gather_raw,
...
@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
     reduce_scatter,
     reduce_scatter_raw,
 )
-from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.distributed import ProcessGroup


 class FusedDenseFunc(torch.autograd.Function):
...
flash_attn/ops/layer_norm.py  +535 −110

 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

-import dropout_layer_norm
 import torch
 from torch.nn import init

+import dropout_layer_norm
+

 def maybe_align(x, alignment_in_bytes=16):
-    """Assume that x already has last dim divisible by alignment_in_bytes
-    """
+    """Assume that x already has last dim divisible by alignment_in_bytes"""
     # TD [2023-07-04] I'm not 100% sure that clone will align the memory
     # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
     return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


-def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
-                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
-    """ Assume that arguments are contiguous and aligned to 16 bytes
-    """
+def _dropout_add_layer_norm_forward(
+    x0,
+    residual,
+    gamma,
+    beta,
+    rowscale,
+    colscale,
+    dropout_p,
+    epsilon,
+    residual_in_fp32=False,
+    is_rms_norm=False,
+):
+    """Assume that arguments are contiguous and aligned to 16 bytes"""
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
     residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
-        1.0, 0, None, residual_in_fp32, is_rms_norm
+        x0mat,
+        residualmat,
+        gamma,
+        beta,
+        rowscale,
+        colscale,
+        None,
+        None,
+        dropout_p,
+        epsilon,
+        1.0,
+        0,
+        None,
+        residual_in_fp32,
+        is_rms_norm,
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
...
@@ -46,10 +79,25 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     if colscale is not None:
-        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
+        assert x0 is not None, "x0 is required to compute the gradient of colscale"
...
@@ -108,18 +210,44 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
...
     x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
     residualmat = residual.view((-1, hidden_size)) if residual is not None else None
-    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
-        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, None,
-        residual_in_fp32, is_rms_norm
+    (
+        z0mat,
+        z1mat,
+        xmat,
+        dmask0,
+        dmask1,
+        mu,
+        rsigma,
+    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
+        x0mat,
+        x1mat,
+        residualmat,
+        gamma0,
+        beta0,
+        gamma1,
+        beta1,
+        dropout_p,
+        epsilon,
+        None,
+        residual_in_fp32,
+        is_rms_norm,
     )
     # dmask0 and dmask1 are None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
...
@@ -158,26 +332,43 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
...
         if not return_dmask:
-            return (zmat.view(x0.shape) if not prenorm
-                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
+            return (
+                zmat.view(x0.shape)
+                if not prenorm
+                else (zmat.view(x0.shape), xmat.view(x0.shape))
+            )
         else:
-            dmask = (dmask.view(x0.shape) if dropout_p > 0.
-                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            dmask = (
+                dmask.view(x0.shape)
+                if dropout_p > 0.0
+                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
+            )
             ctx.mark_non_differentiable(dmask)
-            return ((zmat.view(x0.shape), dmask) if not prenorm
-                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
+            return (
+                (zmat.view(x0.shape), dmask)
+                if not prenorm
+                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
+            )

     @staticmethod
     def backward(ctx, dz, *args):
...
@@ -370,6 +788,13 @@ class DropoutAddLayerNorm(torch.nn.Module):
         init.zeros_(self.bias)

     def forward(self, x0, residual=None):
-        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
-                                      self.p if self.training else 0.0, self.eps,
-                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
+        return dropout_add_layer_norm(
+            x0,
+            residual,
+            self.weight,
+            self.bias,
+            self.p if self.training else 0.0,
+            self.eps,
+            prenorm=self.prenorm,
+            residual_in_fp32=self.residual_in_fp32,
+        )
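dropout_add_layer_norm above fuses dropout, the residual add, and LayerNorm (or RMSNorm) into one kernel; the docstrings in the diff only spell out the residual_in_fp32 and prenorm switches. The sketch below is a plausible unfused reference for those semantics written with plain PyTorch ops, with the rowscale/layerscale/subset paths omitted; it is a reading of the signature, not the kernel's code.

# Unfused reference for what dropout_add_layer_norm computes, based on the signature and
# docstrings visible in this diff. Simplified: no rowscale/layerscale/x0_subset handling.
import torch
import torch.nn.functional as F


def dropout_add_layer_norm_reference(x0, residual, weight, bias, dropout_p, epsilon,
                                     prenorm=False, residual_in_fp32=False, training=True):
    dropped = F.dropout(x0, p=dropout_p, training=training)
    if residual is not None:
        x = dropped + residual
    else:
        # residual_in_fp32 only matters when there is no residual stream yet.
        x = dropped.float() if residual_in_fp32 else dropped
    out = F.layer_norm(x.to(weight.dtype), (x.shape[-1],), weight, bias, eps=epsilon)
    # prenorm=True also returns the pre-norm stream so the next block can reuse it.
    return (out, x) if prenorm else out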
flash_attn/ops/rms_norm.py  +113 −28

...
@@ -4,60 +4,130 @@
 import torch
 from torch.nn import init

-from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
-from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
+from flash_attn.ops.layer_norm import (
+    DropoutAddLayerNormFn,
+    DropoutAddLayerNormParallelResidualFn,
+    DropoutAddLayerNormSubsetFn,
+)


 def rms_norm(x, weight, epsilon):
-    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
-                                       False, True)
+    return DropoutAddLayerNormFn.apply(
+        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
+    )


-def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
-                         layerscale=None, prenorm=False, residual_in_fp32=False,
-                         return_dropout_mask=False):
+def dropout_add_rms_norm(
+    x0,
+    residual,
+    weight,
+    bias,
+    dropout_p,
+    epsilon,
+    rowscale=None,
+    layerscale=None,
+    prenorm=False,
+    residual_in_fp32=False,
+    return_dropout_mask=False,
+):
     """residual_in_fp32 only has an effect if residual is None.
     Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormFn.apply(
-        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32,
-        prenorm, True, return_dropout_mask
+        x0,
+        residual,
+        weight,
+        bias,
+        rowscale,
+        layerscale,
+        dropout_p,
+        epsilon,
+        residual_in_fp32,
+        prenorm,
+        True,
+        return_dropout_mask,
     )
...
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.eps = eps
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
-        self.register_parameter('bias', None)
+        self.register_parameter("bias", None)
         self.reset_parameters()

     def reset_parameters(self):
...
@@ -68,22 +138,37 @@ class RMSNorm(torch.nn.Module):
 class DropoutAddRMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
-                 device=None, dtype=None):
-        factory_kwargs = {'device': device, 'dtype': dtype}
+    def __init__(
+        self,
+        hidden_size,
+        prenorm=False,
+        p=0.0,
+        eps=1e-5,
+        residual_in_fp32=False,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.prenorm = prenorm
         self.p = p
         self.eps = eps
         self.residual_in_fp32 = residual_in_fp32
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
-        self.register_parameter('bias', None)
+        self.register_parameter("bias", None)
         self.reset_parameters()

     def reset_parameters(self):
         init.ones_(self.weight)

     def forward(self, x0, residual=None):
-        return dropout_add_rms_norm(x0, residual, self.weight, None,
-                                    self.p if self.training else 0.0, self.eps,
-                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
+        return dropout_add_rms_norm(
+            x0,
+            residual,
+            self.weight,
+            None,
+            self.p if self.training else 0.0,
+            self.eps,
+            prenorm=self.prenorm,
+            residual_in_fp32=self.residual_in_fp32,
+        )
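RMSNorm above stores only a weight and an eps (bias is registered as None) and dispatches to the fused kernel with is_rms_norm=True. For reference, the textbook RMS normalization this corresponds to can be written in a few lines of plain PyTorch; this is the standard formula, not the kernel's implementation.

# Plain-PyTorch reference for RMS normalization (the is_rms_norm=True path above):
# normalize by the root-mean-square of the last dimension, then scale by a learned weight.
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * inv_rms * weight


x = torch.randn(4, 16)
weight = torch.ones(16)
out = rms_norm_reference(x, weight)
assert out.shape == x.shape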
flash_attn/ops/triton/k_activations.py  +4 −4

...
@@ -11,7 +11,6 @@ from typing import Optional
 import triton
 import triton.language as tl

-
 _sqrt2pi = math.sqrt(2.0 / math.pi)
 _sqrt1_2 = math.sqrt(1.0 / 2)
 _gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
...
@@ -142,6 +141,7 @@ def gelu_grad(x):
     pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
     return cdf + x * pdf

+
 @triton.jit
 def gelu_approx(x):
     """
...
@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
     # CREDITS: Fast implementation proposed in
     # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
     tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
-    return 0.5 * x * (
-        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
-    ) + 0.5 * (1 + tanh_out)
+    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
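gelu_grad above evaluates the exact GELU derivative as cdf + x * pdf, with cdf the standard normal CDF and pdf its density. The following plain-PyTorch cross-check of that closed form against autograd is a sanity test only, not part of the Triton code.

# Cross-check of the closed-form GELU derivative used in gelu_grad above:
# d/dx [x * Phi(x)] = Phi(x) + x * phi(x). Plain PyTorch stand-in for the Triton kernel.
import math

import torch

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
gelu_exact = x * 0.5 * (1.0 + torch.erf(x * math.sqrt(0.5)))

cdf = 0.5 * (1.0 + torch.erf(x * math.sqrt(0.5)))
pdf = torch.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)
closed_form_grad = cdf + x * pdf

(autograd_grad,) = torch.autograd.grad(gelu_exact.sum(), x)
assert torch.allclose(closed_form_grad, autograd_grad, atol=1e-10)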
flash_attn/ops/triton/linear.py
View file @
f1a73d07
...
...
@@ -9,8 +9,14 @@ from torch.autograd.function import FunctionCtx
from
torch.cuda.amp
import
custom_fwd
from
triton.ops.matmul_perf_model
import
early_config_prune
,
estimate_matmul_time
from
flash_attn.ops.triton.k_activations
import
gelu
,
gelu_grad
,
gelu_approx
,
gelu_approx_grad
,
squared_relu
,
squared_relu_grad
from
flash_attn.ops.triton.k_activations
import
(
gelu
,
gelu_approx
,
gelu_approx_grad
,
gelu_grad
,
squared_relu
,
squared_relu_grad
,
)
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
...
...
@@ -28,7 +34,12 @@ def get_configs_io_bound():
num_warps
=
2
if
block_n
<=
64
else
4
configs
.
append
(
triton
.
Config
(
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
1
},
{
"BLOCK_M"
:
block_m
,
"BLOCK_N"
:
block_n
,
"BLOCK_K"
:
block_k
,
"SPLIT_K"
:
1
,
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
)
...
...
@@ -43,29 +54,75 @@ def get_configs_io_bound():
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
32
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
# good for int8
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
),
triton
.
Config
({
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
({
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
3
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
256
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
128
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
128
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
4
,
num_warps
=
4
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
32
,
"BLOCK_K"
:
64
,
"SPLIT_K"
:
1
},
num_stages
=
5
,
num_warps
=
2
),
]
+
get_configs_io_bound
(),
key
=
[
"CACHE_KEY_M"
,
"CACHE_KEY_N"
,
"CACHE_KEY_K"
],
prune_configs_by
=
{
"early_config_prune"
:
early_config_prune
,
"perf_model"
:
estimate_matmul_time
,
"top_k"
:
10
},
prune_configs_by
=
{
"early_config_prune"
:
early_config_prune
,
"perf_model"
:
estimate_matmul_time
,
"top_k"
:
10
,
},
)
@
triton
.
heuristics
(
{
...
...
@@ -204,7 +261,7 @@ def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> torch.Tensor:
    """
...
@@ -221,7 +278,7 @@ def triton_linear_act(
    # dtype = torch.get_autocast_gpu_dtype()
    # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
...
@@ -233,12 +290,20 @@ def triton_linear_act(
    weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None
    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"
    M, K = x_reshaped.shape
    N, K = weight.shape
...
@@ -278,35 +343,83 @@ def triton_linear_act(
    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (
            output.reshape(*batch_shape, output.shape[-1]),
            act_input.reshape(*batch_shape, act_input.shape[-1]),
        )
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
...
@@ -395,7 +508,7 @@ def kernel_bwd(
        B += BLOCK_K * stride_bk
    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION != "id":
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
        if ACTIVATION == "gelu":
...
@@ -418,7 +531,7 @@ def kernel_bwd(
def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
...
@@ -430,7 +543,7 @@ def triton_dgrad_act(
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
...
@@ -441,10 +554,14 @@ def triton_dgrad_act(
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"
    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
...
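Taken together, the hunks above only fix the quoting and wrapping of `triton_linear_act` / `triton_dgrad_act`; behavior is unchanged. For orientation, a minimal usage sketch of the fused linear + activation entry point follows. The shapes, dtypes, and the CUDA device are illustrative assumptions, not part of the commit.

import torch

from flash_attn.ops.triton.linear import triton_linear_act

# Illustrative shapes: weight is (out_features, in_features), bias is (out_features,).
x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.float16)
weight = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
bias = torch.randn(4096, device="cuda", dtype=torch.float16)

# With save_act_input=True the call returns (output, act_input), matching the reshape
# logic shown in the diff; activation must be one of "id", "gelu", "gelu_approx",
# "squared_relu".
out, act_input = triton_linear_act(
    x, weight, bias, activation="squared_relu", save_act_input=True
)
print(out.shape)  # (4, 512, 4096)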
flash_attn/ops/triton/mlp.py  View file @ f1a73d07

# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act


class FusedDenseSqreluDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
...
@@ -23,8 +21,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        """
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [
                a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]
            ]
        is_bf16 = x.dtype == torch.bfloat16
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
...
@@ -35,13 +34,18 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if is_bf16:
            act_input = fused_dense_cuda.linear_bias_forward(
                x.reshape(batch_dim, n), weight1, bias1
            )
            output1 = sqrelu_fwd(act_input)
        else:
            save_act_input = checkpoint_lvl != 2
            result = triton_linear_act(
                x.reshape(batch_dim, n),
                weight1,
                bias1,
                activation="squared_relu",
                save_act_input=save_act_input,
            )
            if save_act_input:
                output1, act_input = result
...
@@ -69,16 +73,21 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        if checkpoint_lvl == 0:
            act_input, output1 = rest
        elif checkpoint_lvl == 1:
            (act_input,) = rest
            output1 = sqrelu_fwd(act_input)
        elif checkpoint_lvl == 2:
            if is_bf16:
                act_input = fused_dense_cuda.linear_bias_forward(
                    x.reshape(batch_dim, n), weight1, bias1
                )
                output1 = sqrelu_fwd(act_input)
            else:
                output1, act_input = triton_linear_act(
                    x.reshape(batch_dim, n),
                    weight1,
                    bias1,
                    activation="squared_relu",
                    save_act_input=True,
                )
        if is_bf16:
...
@@ -92,8 +101,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        else:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_act_input = triton_dgrad_act(
                grad_output, weight2, activation="squared_relu", act_input=act_input
            )
        grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
            x.reshape(batch_dim, n), weight1, grad_act_input
        )
...
@@ -104,9 +114,17 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply


class FusedDenseSqreluDense(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        checkpoint_lvl=0,
        device=None,
        dtype=None,
    ):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
...
@@ -114,7 +132,7 @@ class FusedDenseSqreluDense(nn.Module):
            2: recompute gelu_in and gelu_out in the bwd
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
...
@@ -126,6 +144,6 @@ class FusedDenseSqreluDense(nn.Module):
    def forward(self, x):
        assert x.is_cuda
        return fused_dense_sqrelu_dense_function(
            x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl
        )
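For context, a minimal, hedged usage sketch of the `FusedDenseSqreluDense` module reformatted above (Linear -> squared ReLU -> Linear). The sizes, dtype, and CUDA device are assumptions for illustration; as the code shows, fp16 takes the Triton fused path while bf16 falls back to fused_dense_cuda.

import torch

from flash_attn.ops.triton.mlp import FusedDenseSqreluDense

# Illustrative sizes; checkpoint_lvl=1 saves act_input and recomputes the
# squared-ReLU output in the backward pass.
mlp = FusedDenseSqreluDense(
    in_features=1024, hidden_features=4096, checkpoint_lvl=1, device="cuda", dtype=torch.float16
)
x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
y = mlp(x)          # forward asserts x.is_cuda
y.sum().backward()  # gradients flow through the fused backward shown above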
flash_attn/utils/benchmark.py  View file @ f1a73d07

...
@@ -5,31 +5,43 @@ import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
...
@@ -37,7 +49,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
...
@@ -46,22 +59,31 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
...
@@ -69,68 +91,142 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            g = torch.randn_like(fn(*inputs, **kwinputs))
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
...
@@ -141,9 +237,10 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
...
@@ -151,14 +248,14 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
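A short usage sketch of the benchmarking helpers above may help; the workload (a plain matmul) and the sizes are illustrative assumptions, and a CUDA device is required.

import torch

from flash_attn.utils.benchmark import benchmark_all, benchmark_memory, pytorch_profiler

x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
w = torch.randn(4096, 4096, device="cuda", requires_grad=True)
fn = lambda a, b: a @ b

# Times forward, backward, and forward+backward with torch.utils.benchmark,
# optionally under autocast (amp=True, amp_dtype=torch.float16).
benchmark_all(fn, x, w, repeats=10, desc="matmul", amp=True)
benchmark_memory(fn, x, w, desc="matmul")            # peak CUDA memory of one call, in GB
pytorch_profiler(fn, x, w, backward=True, cpu=True)  # kernel-level table; trace_filename exports a Chrome trace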
flash_attn/utils/distributed.py  View file @ f1a73d07

...
@@ -17,10 +17,12 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
...
@@ -28,11 +30,12 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
...
@@ -102,8 +105,9 @@ all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
...
@@ -116,8 +120,9 @@ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
...
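As a sketch of the shape contract of the raw collectives above (an assumption: torch.distributed has already been initialized, e.g. under torchrun with a NCCL backend):

import torch
import torch.distributed as dist

from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw

group = dist.group.WORLD
world_size = dist.get_world_size(group)

x = torch.randn(8, 1024, device="cuda")
gathered, handle = all_gather_raw(x, group, async_op=True)
handle.wait()  # async_op=True returns a work handle; the synchronous path returns handle=None
assert gathered.shape[0] == world_size * x.shape[0]

y = torch.randn(8 * world_size, 1024, device="cuda")
scattered, _ = reduce_scatter_raw(y, group)  # first dim must be divisible by world_size
assert scattered.shape[0] == y.shape[0] // world_size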
flash_attn/utils/generation.py  View file @ f1a73d07

# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union

import torch
from einops import rearrange
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
...
@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
...
@@ -38,11 +36,13 @@ def modify_logits_for_top_p_filtering(logits, top_p):
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, float("-inf"))


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
...
@@ -54,7 +54,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
...
@@ -62,17 +62,31 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                    dim=-1
                ),
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )


def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
...
@@ -92,19 +106,24 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )
    scores = []
    with torch.inference_mode():
        if timing:
...
@@ -123,18 +142,32 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
            if not cg:
                logits = model(
                    rearrange(next_token, "b -> b 1"),
                    position_ids=position_ids,
                    inference_params=inference_params,
                    last_token_only=True,
                ).logits
            else:
                logits = model._decoding_cache.run(
                    rearrange(next_token, "b -> b 1"),
                    position_ids,
                    inference_params.sequence_len_offset,
                )
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits if not cg else logits.clone())
            if (
                teacher_outputs is None
                or teacher_output_len <= inference_params.sequence_len_offset + 1
            ):
                next_token = sample(logits, top_k=top_k, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
...
@@ -148,30 +181,45 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
    )


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
...
@@ -179,9 +227,13 @@ def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {
        i: (
            torch.empty(k_cache_shape, device=device, dtype=dtype),
            torch.empty(v_cache_shape, device=device, dtype=dtype),
        )
        for i in layers
    }


def seqlen_to_seqlen_type(seqlen: int) -> int:
...
@@ -211,49 +263,70 @@ class DecodingCGCache:
@torch.inference_mode()
def update_graph_cache(
    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen_,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
            input_ids, position_ids, seqlen
        )

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
...
@@ -275,8 +348,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            ).logits
        s.synchronize()
    # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
    # which requires that graph launch and non-captured launch to not overlap (I think,
...
@@ -288,8 +365,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            last_token_only=True,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
...
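For orientation, a minimal sketch of the `sample` helper shown above; the logits tensor is a random placeholder. Per the `decode` docstring, top_k=0 disables the top-k cut-off, and when both top_k and top_p are set, top-k is applied first.

import torch

from flash_attn.utils.generation import sample

logits = torch.randn(4, 32000)  # (batch, vocab_size), illustrative
greedy = sample(logits, top_k=1)                               # argmax decoding
nucleus = sample(logits, top_k=0, top_p=0.9, temperature=0.8)  # pure top-p sampling
mixed = sample(logits, top_k=50, top_p=0.95)                   # top-k, then top-p
print(greedy.shape, nucleus.shape, mixed.shape)  # each is (4,)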
flash_attn/utils/pretrained.py  View file @ f1a73d07

...
@@ -3,13 +3,18 @@ from functools import partial
import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None
...
@@ -20,19 +25,23 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
...
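A minimal usage sketch of `state_dict_from_pretrained` (the model name is an illustrative placeholder): the helper resolves plain, sharded, and safetensors checkpoints, and loads to CPU first when a non-fp32 dtype is requested, as the code above shows.

import torch

from flash_attn.utils.pretrained import state_dict_from_pretrained

# With dtype=torch.float16 the weights are mapped to CPU first
# ("cpu" if dtype not in [torch.float32, None]) before conversion.
state_dict = state_dict_from_pretrained("gpt2", device="cuda", dtype=torch.float16)
print(len(state_dict), next(iter(state_dict)))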