gaoqiong / flash-attention / Commits

Commit f1a73d07, authored Aug 18, 2023 by Tri Dao (parent: cbb4cf5f)

    Run isort and black on python files

Changes: 34. Showing 14 changed files with 1403 additions and 459 deletions (+1403, -459).
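The commit is a pure formatting pass. Below is a minimal sketch of how such a pass is typically reproduced locally; the exact isort/black versions, options, and target directory are assumptions, since the repository's tooling configuration is not part of this diff.

# Hypothetical reproduction of the formatting pass (not taken from the commit):
# sort imports with isort, then reformat code with black, over the package directory.
import subprocess

for tool in (["isort", "flash_attn"], ["black", "flash_attn"]):
    subprocess.run(tool, check=True)  # raises CalledProcessError if the tool fails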
flash_attn/modules/embedding.py           +86    -53
flash_attn/modules/mha.py                  +6     -2
flash_attn/modules/mlp.py                 +97    -33
flash_attn/ops/activations.py             +11     -6
flash_attn/ops/fused_dense.py              +4     -3
flash_attn/ops/layer_norm.py             +535   -110
flash_attn/ops/rms_norm.py               +113    -28
flash_attn/ops/triton/k_activations.py     +4     -4
flash_attn/ops/triton/linear.py          +173    -56
flash_attn/ops/triton/mlp.py              +41    -23
flash_attn/utils/benchmark.py            +155    -58
flash_attn/utils/distributed.py           +18    -13
flash_attn/utils/generation.py           +141    -60
flash_attn/utils/pretrained.py            +19    -10
flash_attn/modules/embedding.py  (view file @ f1a73d07)

@@ -2,42 +2,52 @@
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from flash_attn.utils.distributed import all_reduce, reduce_scatter


class GPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        padding_idx=None,
        word_embed_proj_dim=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
            the project up to embed_dim
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(
                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(
                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = nn.Linear(
                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
            )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        ...

@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):
class BertEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        ...

@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):
class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
        ...

@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):
class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            ...

@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):
    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        ...

@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
        else:
            partition_dim = self.position_embeddings.embedding_dim
            rank = torch.distributed.get_rank(self.process_group)
            embeddings[
                ..., rank * partition_dim : (rank + 1) * partition_dim
            ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
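As a quick illustration of the embedding API shown in this file, the sketch below constructs a GPT2Embeddings module and runs a forward pass. It is a hedged usage sketch, not part of the commit: the sizes are arbitrary and it assumes the flash_attn package is importable.

# Hedged usage sketch (sizes are illustrative; assumes flash_attn is installed).
import torch
from flash_attn.modules.embedding import GPT2Embeddings

emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
input_ids = torch.randint(0, 50257, (2, 16))  # (batch, seqlen)
hidden = emb(input_ids)                        # word + position embeddings
print(hidden.shape)                            # torch.Size([2, 16, 768])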
flash_attn/modules/mha.py  (view file @ f1a73d07)

@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        ...
flash_attn/modules/mlp.py  (view file @ f1a73d07)

@@ -17,10 +17,19 @@ except ImportError:
class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        ...

@@ -37,21 +46,42 @@ class Mlp(nn.Module):
class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        ...

@@ -61,15 +91,25 @@ class ParallelMLP(nn.Module):
class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        ...

@@ -88,24 +128,48 @@ class GatedMlp(nn.Module):
class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        ...
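The GatedMlp/ParallelGatedMlp classes above default their hidden width to int(8 * in_features / 3), rounded up to a multiple of multiple_of. A small worked example of that arithmetic, with values chosen purely for illustration:

# Worked example of the hidden_features rounding used by GatedMlp (illustrative values).
in_features = 4096
multiple_of = 256
hidden_features = int(8 * in_features / 3)  # 10922
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
print(hidden_features)  # 11008, the next multiple of 256 above 10922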
flash_attn/ops/activations.py  (view file @ f1a73d07)

@@ -5,7 +5,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
...

@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
    ...

@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
...

@@ -63,7 +65,9 @@ def gelu_fwd(x):
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return (ff * g).to(dtype=x.dtype)
    ...

@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
...
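The bias_gelu/gelu_fwd functions above implement the tanh approximation of GELU whose exact gradient is spelled out in the comments. As a sanity-check sketch (not part of the commit), the hand-written approximation can be compared against PyTorch's built-in tanh-approximate GELU:

# Sketch: compare the hand-written tanh approximation with torch's approximate GELU.
import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float32)
approx = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
print(torch.allclose(approx, F.gelu(x, approximate="tanh"), atol=1e-5))  # expected: True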
flash_attn/ops/fused_dense.py  (view file @ f1a73d07)

@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather_raw,
    ...

@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
    reduce_scatter,
    reduce_scatter_raw,
)


class FusedDenseFunc(torch.autograd.Function):
...
flash_attn/ops/layer_norm.py
View file @
f1a73d07
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import
dropout_layer_norm
import
torch
import
torch
from
torch.nn
import
init
from
torch.nn
import
init
import
dropout_layer_norm
def
maybe_align
(
x
,
alignment_in_bytes
=
16
):
def
maybe_align
(
x
,
alignment_in_bytes
=
16
):
"""Assume that x already has last dim divisible by alignment_in_bytes
"""Assume that x already has last dim divisible by alignment_in_bytes"""
"""
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
return
x
if
x
.
data_ptr
()
%
alignment_in_bytes
==
0
else
x
.
clone
()
return
x
if
x
.
data_ptr
()
%
alignment_in_bytes
==
0
else
x
.
clone
()
def
_dropout_add_layer_norm_forward
(
x0
,
residual
,
gamma
,
beta
,
rowscale
,
colscale
,
dropout_p
,
def
_dropout_add_layer_norm_forward
(
epsilon
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
):
x0
,
""" Assume that arguments are contiguous and aligned to 16 bytes
residual
,
"""
gamma
,
beta
,
rowscale
,
colscale
,
dropout_p
,
epsilon
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size
=
gamma
.
numel
()
hidden_size
=
gamma
.
numel
()
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
residualmat
=
residual
.
view
((
-
1
,
hidden_size
))
if
residual
is
not
None
else
None
residualmat
=
residual
.
view
((
-
1
,
hidden_size
))
if
residual
is
not
None
else
None
rowscale
=
rowscale
.
view
(
-
1
)
if
rowscale
is
not
None
else
None
rowscale
=
rowscale
.
view
(
-
1
)
if
rowscale
is
not
None
else
None
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
dropout_layer_norm
.
dropout_add_ln_fwd
(
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
dropout_layer_norm
.
dropout_add_ln_fwd
(
x0mat
,
residualmat
,
gamma
,
beta
,
rowscale
,
colscale
,
None
,
None
,
dropout_p
,
epsilon
,
x0mat
,
1.0
,
0
,
None
,
residual_in_fp32
,
is_rms_norm
residualmat
,
gamma
,
beta
,
rowscale
,
colscale
,
None
,
None
,
dropout_p
,
epsilon
,
1.0
,
0
,
None
,
residual_in_fp32
,
is_rms_norm
,
)
)
# dmask is None if dropout_p == 0.0
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return
zmat
,
xmat
if
xmat
is
not
None
else
x0mat
,
dmask
,
mu
,
rsigma
return
zmat
,
xmat
if
xmat
is
not
None
else
x0mat
,
dmask
,
mu
,
rsigma
def
_dropout_add_layer_norm_backward
(
dz
,
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
rowscale
,
colscale
,
def
_dropout_add_layer_norm_backward
(
dropout_p
,
has_residual
,
is_rms_norm
=
False
):
dz
,
""" Assume that arguments are contiguous and aligned to 16 bytes
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
rowscale
,
colscale
,
dropout_p
,
has_residual
,
is_rms_norm
=
False
,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
x0 must not be None if we have colscale.
...
@@ -46,10 +79,25 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
...
@@ -46,10 +79,25 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
if
x0
is
not
None
else
None
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
if
x0
is
not
None
else
None
rowscale
=
rowscale
.
view
(
-
1
)
if
rowscale
is
not
None
else
None
rowscale
=
rowscale
.
view
(
-
1
)
if
rowscale
is
not
None
else
None
if
colscale
is
not
None
:
if
colscale
is
not
None
:
assert
x0
is
not
None
,
'
x0 is required to compute the gradient of colscale
'
assert
x0
is
not
None
,
"
x0 is required to compute the gradient of colscale
"
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
_
,
_
,
*
rest
=
dropout_layer_norm
.
dropout_add_ln_bwd
(
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
_
,
_
,
*
rest
=
dropout_layer_norm
.
dropout_add_ln_bwd
(
dzmat
,
dxmat
,
xmat
,
x0mat
,
dmask
,
mu
,
rsigma
,
gamma
,
rowscale
,
colscale
,
None
,
None
,
dzmat
,
dropout_p
,
1.0
,
0
,
has_residual
,
is_rms_norm
dxmat
,
xmat
,
x0mat
,
dmask
,
mu
,
rsigma
,
gamma
,
rowscale
,
colscale
,
None
,
None
,
dropout_p
,
1.0
,
0
,
has_residual
,
is_rms_norm
,
)
)
# dresidualmat is None if not has_residual
# dresidualmat is None if not has_residual
if
colscale
is
None
:
if
colscale
is
None
:
...
@@ -59,29 +107,68 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
...
@@ -59,29 +107,68 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
return
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
dcolscale
return
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
dcolscale
def
_dropout_add_layer_norm_subset_forward
(
x0
,
residual
,
gamma
,
beta
,
colscale
,
x0_subset
,
def
_dropout_add_layer_norm_subset_forward
(
out_subset
,
dropout_p
,
epsilon
,
rowscale_const
,
x0
,
out_numrows
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
):
residual
,
""" Assume that arguments are contiguous and aligned to 16 bytes
gamma
,
"""
beta
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
rowscale_const
,
out_numrows
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size
=
gamma
.
numel
()
hidden_size
=
gamma
.
numel
()
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
residualmat
=
residual
.
view
((
-
1
,
hidden_size
))
if
residual
is
not
None
else
None
residualmat
=
residual
.
view
((
-
1
,
hidden_size
))
if
residual
is
not
None
else
None
x0_subset
=
x0_subset
.
view
(
-
1
)
if
x0_subset
is
not
None
else
None
x0_subset
=
x0_subset
.
view
(
-
1
)
if
x0_subset
is
not
None
else
None
out_subset
=
out_subset
.
view
(
-
1
)
if
out_subset
is
not
None
else
None
out_subset
=
out_subset
.
view
(
-
1
)
if
out_subset
is
not
None
else
None
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
dropout_layer_norm
.
dropout_add_ln_fwd
(
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
dropout_layer_norm
.
dropout_add_ln_fwd
(
x0mat
,
residualmat
,
gamma
,
beta
,
None
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
x0mat
,
rowscale_const
,
out_numrows
,
None
,
residual_in_fp32
,
is_rms_norm
residualmat
,
gamma
,
beta
,
None
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
rowscale_const
,
out_numrows
,
None
,
residual_in_fp32
,
is_rms_norm
,
)
)
# dmask is None if dropout_p == 0.0
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return
zmat
,
xmat
if
xmat
is
not
None
else
x0mat
,
dmask
,
mu
,
rsigma
return
zmat
,
xmat
if
xmat
is
not
None
else
x0mat
,
dmask
,
mu
,
rsigma
def
_dropout_add_layer_norm_subset_backward
(
dz
,
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
colscale
,
def
_dropout_add_layer_norm_subset_backward
(
x0_subset
,
out_subset
,
dropout_p
,
rowscale_const
,
dz
,
x0_numrows
,
has_residual
,
is_rms_norm
=
False
):
dx
,
""" Assume that arguments are contiguous and aligned to 16 bytes
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
rowscale_const
,
x0_numrows
,
has_residual
,
is_rms_norm
=
False
,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
x0 must not be None if we have colscale.
...
@@ -94,10 +181,25 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
...
@@ -94,10 +181,25 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
x0_subset
=
x0_subset
.
view
(
-
1
)
if
x0_subset
is
not
None
else
None
x0_subset
=
x0_subset
.
view
(
-
1
)
if
x0_subset
is
not
None
else
None
out_subset
=
out_subset
.
view
(
-
1
)
if
out_subset
is
not
None
else
None
out_subset
=
out_subset
.
view
(
-
1
)
if
out_subset
is
not
None
else
None
if
colscale
is
not
None
:
if
colscale
is
not
None
:
assert
x0
is
not
None
,
'
x0 is required to compute the gradient of colscale
'
assert
x0
is
not
None
,
"
x0 is required to compute the gradient of colscale
"
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
_
,
_
,
*
rest
=
dropout_layer_norm
.
dropout_add_ln_bwd
(
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
_
,
_
,
*
rest
=
dropout_layer_norm
.
dropout_add_ln_bwd
(
dzmat
,
dxmat
,
xmat
,
x0mat
,
dmask
,
mu
,
rsigma
,
gamma
,
None
,
colscale
,
x0_subset
,
out_subset
,
dzmat
,
dropout_p
,
rowscale_const
,
x0_numrows
,
has_residual
,
is_rms_norm
dxmat
,
xmat
,
x0mat
,
dmask
,
mu
,
rsigma
,
gamma
,
None
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
rowscale_const
,
x0_numrows
,
has_residual
,
is_rms_norm
,
)
)
# dresidualmat is None if not has_residual
# dresidualmat is None if not has_residual
if
colscale
is
None
:
if
colscale
is
None
:
...
@@ -108,18 +210,44 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
...
@@ -108,18 +210,44 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
def
_dropout_add_layer_norm_parallel_residual_forward
(
def
_dropout_add_layer_norm_parallel_residual_forward
(
x0
,
x1
,
residual
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
x0
,
epsilon
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
x1
,
residual
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
residual_in_fp32
=
False
,
is_rms_norm
=
False
,
):
):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""Assume that arguments are contiguous and aligned to 16 bytes"""
"""
hidden_size
=
gamma0
.
numel
()
hidden_size
=
gamma0
.
numel
()
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
x0mat
=
x0
.
view
((
-
1
,
hidden_size
))
x1mat
=
x1
.
view
((
-
1
,
hidden_size
))
if
x1
is
not
None
else
None
x1mat
=
x1
.
view
((
-
1
,
hidden_size
))
if
x1
is
not
None
else
None
residualmat
=
residual
.
view
((
-
1
,
hidden_size
))
if
residual
is
not
None
else
None
residualmat
=
residual
.
view
((
-
1
,
hidden_size
))
if
residual
is
not
None
else
None
z0mat
,
z1mat
,
xmat
,
dmask0
,
dmask1
,
mu
,
rsigma
=
dropout_layer_norm
.
dropout_add_ln_parallel_residual_fwd
(
(
x0mat
,
x1mat
,
residualmat
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
z0mat
,
None
,
residual_in_fp32
,
is_rms_norm
z1mat
,
xmat
,
dmask0
,
dmask1
,
mu
,
rsigma
,
)
=
dropout_layer_norm
.
dropout_add_ln_parallel_residual_fwd
(
x0mat
,
x1mat
,
residualmat
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
None
,
residual_in_fp32
,
is_rms_norm
,
)
)
# dmask0 and dmask1 are None if dropout_p == 0.0
# dmask0 and dmask1 are None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
...
@@ -127,10 +255,22 @@ def _dropout_add_layer_norm_parallel_residual_forward(
...
@@ -127,10 +255,22 @@ def _dropout_add_layer_norm_parallel_residual_forward(
def
_dropout_add_layer_norm_parallel_residual_backward
(
def
_dropout_add_layer_norm_parallel_residual_backward
(
dz0
,
dz1
,
dx
,
x
,
dmask0
,
dmask1
,
mu
,
rsigma
,
gamma0
,
gamma1
,
dz0
,
dropout_p
,
has_x1
,
has_residual
,
is_rms_norm
=
False
dz1
,
dx
,
x
,
dmask0
,
dmask1
,
mu
,
rsigma
,
gamma0
,
gamma1
,
dropout_p
,
has_x1
,
has_residual
,
is_rms_norm
=
False
,
):
):
"""
Assume that arguments are contiguous and aligned to 16 bytes
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
(x = drop(x0) + residual was not returned in the fwd).
"""
"""
...
@@ -139,9 +279,30 @@ def _dropout_add_layer_norm_parallel_residual_backward(
...
@@ -139,9 +279,30 @@ def _dropout_add_layer_norm_parallel_residual_backward(
dz0mat
=
dz0
.
view
(
xmat
.
shape
)
dz0mat
=
dz0
.
view
(
xmat
.
shape
)
dz1mat
=
dz1
.
view
(
xmat
.
shape
)
if
dz1
is
not
None
else
None
dz1mat
=
dz1
.
view
(
xmat
.
shape
)
if
dz1
is
not
None
else
None
dxmat
=
dx
.
view
(
xmat
.
shape
)
if
dx
is
not
None
else
None
dxmat
=
dx
.
view
(
xmat
.
shape
)
if
dx
is
not
None
else
None
dx0mat
,
dx1mat
,
dresidualmat
,
dgamma0
,
dbeta0
,
dgamma1
,
dbeta1
,
*
rest
=
dropout_layer_norm
.
dropout_add_ln_parallel_residual_bwd
(
(
dz0mat
,
dz1mat
,
dxmat
,
xmat
,
dmask0
,
dmask1
,
mu
,
rsigma
,
gamma0
,
gamma1
,
dx0mat
,
dropout_p
,
has_x1
,
has_residual
,
is_rms_norm
dx1mat
,
dresidualmat
,
dgamma0
,
dbeta0
,
dgamma1
,
dbeta1
,
*
rest
,
)
=
dropout_layer_norm
.
dropout_add_ln_parallel_residual_bwd
(
dz0mat
,
dz1mat
,
dxmat
,
xmat
,
dmask0
,
dmask1
,
mu
,
rsigma
,
gamma0
,
gamma1
,
dropout_p
,
has_x1
,
has_residual
,
is_rms_norm
,
)
)
# dresidualmat is None if not has_residual
# dresidualmat is None if not has_residual
return
dx0mat
,
dx1mat
,
dresidualmat
,
dgamma0
,
dbeta0
,
dgamma1
,
dbeta1
return
dx0mat
,
dx1mat
,
dresidualmat
,
dgamma0
,
dbeta0
,
dgamma1
,
dbeta1
...
@@ -149,8 +310,21 @@ def _dropout_add_layer_norm_parallel_residual_backward(
...
@@ -149,8 +310,21 @@ def _dropout_add_layer_norm_parallel_residual_backward(
class
DropoutAddLayerNormFn
(
torch
.
autograd
.
Function
):
class
DropoutAddLayerNormFn
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
x0
,
residual
,
gamma
,
beta
,
rowscale
,
colscale
,
dropout_p
,
epsilon
,
def
forward
(
residual_in_fp32
=
False
,
prenorm
=
False
,
is_rms_norm
=
False
,
return_dmask
=
False
):
ctx
,
x0
,
residual
,
gamma
,
beta
,
rowscale
,
colscale
,
dropout_p
,
epsilon
,
residual_in_fp32
=
False
,
prenorm
=
False
,
is_rms_norm
=
False
,
return_dmask
=
False
,
):
x0
=
maybe_align
(
x0
.
contiguous
(),
16
)
x0
=
maybe_align
(
x0
.
contiguous
(),
16
)
residual
=
maybe_align
(
residual
.
contiguous
(),
16
)
if
residual
is
not
None
else
None
residual
=
maybe_align
(
residual
.
contiguous
(),
16
)
if
residual
is
not
None
else
None
gamma
=
maybe_align
(
gamma
.
contiguous
(),
16
)
gamma
=
maybe_align
(
gamma
.
contiguous
(),
16
)
...
@@ -158,26 +332,43 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
...
@@ -158,26 +332,43 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
rowscale
=
maybe_align
(
rowscale
.
contiguous
(),
16
)
if
rowscale
is
not
None
else
None
rowscale
=
maybe_align
(
rowscale
.
contiguous
(),
16
)
if
rowscale
is
not
None
else
None
colscale
=
maybe_align
(
colscale
.
contiguous
(),
16
)
if
colscale
is
not
None
else
None
colscale
=
maybe_align
(
colscale
.
contiguous
(),
16
)
if
colscale
is
not
None
else
None
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
_dropout_add_layer_norm_forward
(
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
_dropout_add_layer_norm_forward
(
x0
,
residual
,
gamma
,
beta
,
rowscale
,
colscale
,
dropout_p
,
epsilon
,
x0
,
residual_in_fp32
,
is_rms_norm
residual
,
gamma
,
beta
,
rowscale
,
colscale
,
dropout_p
,
epsilon
,
residual_in_fp32
,
is_rms_norm
,
)
)
# Only need to save x0 if we need to compute gradient wrt colscale
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved
=
x0
if
colscale
is
not
None
else
None
x0_saved
=
x0
if
colscale
is
not
None
else
None
ctx
.
save_for_backward
(
xmat
.
view
(
x0
.
shape
),
x0_saved
,
dmask
,
gamma
,
mu
,
rsigma
,
rowscale
,
colscale
)
ctx
.
save_for_backward
(
xmat
.
view
(
x0
.
shape
),
x0_saved
,
dmask
,
gamma
,
mu
,
rsigma
,
rowscale
,
colscale
)
ctx
.
prenorm
=
prenorm
ctx
.
prenorm
=
prenorm
ctx
.
dropout_p
=
dropout_p
ctx
.
dropout_p
=
dropout_p
ctx
.
has_residual
=
residual
is
not
None
ctx
.
has_residual
=
residual
is
not
None
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
has_beta
=
beta
is
not
None
ctx
.
has_beta
=
beta
is
not
None
if
not
return_dmask
:
if
not
return_dmask
:
return
(
zmat
.
view
(
x0
.
shape
)
if
not
prenorm
return
(
else
(
zmat
.
view
(
x0
.
shape
),
xmat
.
view
(
x0
.
shape
)))
zmat
.
view
(
x0
.
shape
)
if
not
prenorm
else
(
zmat
.
view
(
x0
.
shape
),
xmat
.
view
(
x0
.
shape
))
)
else
:
else
:
dmask
=
(
dmask
.
view
(
x0
.
shape
)
if
dropout_p
>
0.
dmask
=
(
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
.
device
))
dmask
.
view
(
x0
.
shape
)
if
dropout_p
>
0.0
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
.
device
)
)
ctx
.
mark_non_differentiable
(
dmask
)
ctx
.
mark_non_differentiable
(
dmask
)
return
((
zmat
.
view
(
x0
.
shape
),
dmask
)
if
not
prenorm
return
(
else
(
zmat
.
view
(
x0
.
shape
),
xmat
.
view
(
x0
.
shape
),
dmask
))
(
zmat
.
view
(
x0
.
shape
),
dmask
)
if
not
prenorm
else
(
zmat
.
view
(
x0
.
shape
),
xmat
.
view
(
x0
.
shape
),
dmask
)
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
dz
,
*
args
):
def
backward
(
ctx
,
dz
,
*
args
):
...
@@ -189,35 +380,85 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
...
@@ -189,35 +380,85 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
dropout_p
=
ctx
.
dropout_p
dropout_p
=
ctx
.
dropout_p
has_residual
=
ctx
.
has_residual
has_residual
=
ctx
.
has_residual
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
*
rest
=
_dropout_add_layer_norm_backward
(
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
*
rest
=
_dropout_add_layer_norm_backward
(
dz
,
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
rowscale
,
colscale
,
dropout_p
,
has_residual
,
dz
,
ctx
.
is_rms_norm
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
rowscale
,
colscale
,
dropout_p
,
has_residual
,
ctx
.
is_rms_norm
,
)
)
dx0
=
dx0mat
.
view
(
x
.
shape
)
dx0
=
dx0mat
.
view
(
x
.
shape
)
dresidual
=
dresidualmat
.
view
(
x
.
shape
)
if
dresidualmat
is
not
None
else
None
dresidual
=
dresidualmat
.
view
(
x
.
shape
)
if
dresidualmat
is
not
None
else
None
dcolscale
=
rest
[
0
]
if
colscale
is
not
None
else
None
dcolscale
=
rest
[
0
]
if
colscale
is
not
None
else
None
return
(
dx0
,
dresidual
,
dgamma
,
dbeta
if
ctx
.
has_beta
else
None
,
None
,
dcolscale
,
None
,
return
(
None
,
None
,
None
,
None
,
None
)
dx0
,
dresidual
,
dgamma
,
dbeta
if
ctx
.
has_beta
else
None
,
None
,
dcolscale
,
None
,
None
,
None
,
None
,
None
,
None
,
)
class
DropoutAddLayerNormSubsetFn
(
torch
.
autograd
.
Function
):
class
DropoutAddLayerNormSubsetFn
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
x0
,
residual
,
gamma
,
beta
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
def
forward
(
rowscale_const
,
out_numrows
,
residual_in_fp32
=
False
,
ctx
,
prenorm
=
False
,
is_rms_norm
=
False
,
return_dmask
=
False
):
x0
,
residual
,
gamma
,
beta
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
rowscale_const
,
out_numrows
,
residual_in_fp32
=
False
,
prenorm
=
False
,
is_rms_norm
=
False
,
return_dmask
=
False
,
):
x0
=
maybe_align
(
x0
.
contiguous
(),
16
)
x0
=
maybe_align
(
x0
.
contiguous
(),
16
)
residual
=
maybe_align
(
residual
.
contiguous
(),
16
)
if
residual
is
not
None
else
None
residual
=
maybe_align
(
residual
.
contiguous
(),
16
)
if
residual
is
not
None
else
None
gamma
=
maybe_align
(
gamma
.
contiguous
(),
16
)
gamma
=
maybe_align
(
gamma
.
contiguous
(),
16
)
beta
=
maybe_align
(
beta
.
contiguous
(),
16
)
if
beta
is
not
None
else
None
beta
=
maybe_align
(
beta
.
contiguous
(),
16
)
if
beta
is
not
None
else
None
colscale
=
maybe_align
(
colscale
.
contiguous
(),
16
)
if
colscale
is
not
None
else
None
colscale
=
maybe_align
(
colscale
.
contiguous
(),
16
)
if
colscale
is
not
None
else
None
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
_dropout_add_layer_norm_subset_forward
(
zmat
,
xmat
,
dmask
,
mu
,
rsigma
=
_dropout_add_layer_norm_subset_forward
(
x0
,
residual
,
gamma
,
beta
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
x0
,
rowscale_const
,
out_numrows
,
residual_in_fp32
,
is_rms_norm
residual
,
gamma
,
beta
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
epsilon
,
rowscale_const
,
out_numrows
,
residual_in_fp32
,
is_rms_norm
,
)
)
# Only need to save x0 if we need to compute gradient wrt colscale
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved
=
x0
if
colscale
is
not
None
else
None
x0_saved
=
x0
if
colscale
is
not
None
else
None
x_shape
=
(
-
1
,
*
x0
.
shape
[
1
:])
x_shape
=
(
-
1
,
*
x0
.
shape
[
1
:])
ctx
.
save_for_backward
(
xmat
.
view
(
x_shape
),
x0_saved
,
dmask
,
gamma
,
mu
,
rsigma
,
colscale
,
ctx
.
save_for_backward
(
x0_subset
,
out_subset
)
xmat
.
view
(
x_shape
),
x0_saved
,
dmask
,
gamma
,
mu
,
rsigma
,
colscale
,
x0_subset
,
out_subset
)
ctx
.
prenorm
=
prenorm
ctx
.
prenorm
=
prenorm
ctx
.
dropout_p
=
dropout_p
ctx
.
dropout_p
=
dropout_p
ctx
.
rowscale_const
=
rowscale_const
ctx
.
rowscale_const
=
rowscale_const
...
@@ -227,14 +468,16 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
...
@@ -227,14 +468,16 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
ctx
.
has_beta
=
beta
is
not
None
ctx
.
has_beta
=
beta
is
not
None
z_shape
=
(
-
1
,
*
x0
.
shape
[
1
:])
z_shape
=
(
-
1
,
*
x0
.
shape
[
1
:])
if
not
return_dmask
:
if
not
return_dmask
:
return
(
zmat
.
view
(
z_shape
)
if
not
prenorm
return
zmat
.
view
(
z_shape
)
if
not
prenorm
else
(
zmat
.
view
(
z_shape
),
xmat
.
view
(
x0
.
shape
))
else
(
zmat
.
view
(
z_shape
),
xmat
.
view
(
x0
.
shape
)))
else
:
else
:
z
=
zmat
.
view
(
z_shape
)
z
=
zmat
.
view
(
z_shape
)
dmask
=
(
dmask
.
view
(
x0
.
shape
)
if
dropout_p
>
0.
dmask
=
(
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
.
device
))
dmask
.
view
(
x0
.
shape
)
if
dropout_p
>
0.0
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
.
device
)
)
ctx
.
mark_non_differentiable
(
dmask
)
ctx
.
mark_non_differentiable
(
dmask
)
return
(
(
z
,
dmask
)
if
not
prenorm
else
(
z
,
xmat
.
view
(
x_shape
),
dmask
)
)
return
(
z
,
dmask
)
if
not
prenorm
else
(
z
,
xmat
.
view
(
x_shape
),
dmask
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
dz
,
*
args
):
def
backward
(
ctx
,
dz
,
*
args
):
...
@@ -246,20 +489,63 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
...
@@ -246,20 +489,63 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
dropout_p
=
ctx
.
dropout_p
dropout_p
=
ctx
.
dropout_p
has_residual
=
ctx
.
has_residual
has_residual
=
ctx
.
has_residual
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
*
rest
=
_dropout_add_layer_norm_subset_backward
(
dx0mat
,
dresidualmat
,
dgamma
,
dbeta
,
*
rest
=
_dropout_add_layer_norm_subset_backward
(
dz
,
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
dz
,
ctx
.
rowscale_const
,
ctx
.
x0_numrows
,
has_residual
,
ctx
.
is_rms_norm
dx
,
x
,
x0
,
dmask
,
mu
,
rsigma
,
gamma
,
colscale
,
x0_subset
,
out_subset
,
dropout_p
,
ctx
.
rowscale_const
,
ctx
.
x0_numrows
,
has_residual
,
ctx
.
is_rms_norm
,
)
)
dx0
=
dx0mat
.
view
(
-
1
,
*
x
.
shape
[
1
:])
dx0
=
dx0mat
.
view
(
-
1
,
*
x
.
shape
[
1
:])
dresidual
=
dresidualmat
.
view
(
x
.
shape
)
if
dresidualmat
is
not
None
else
None
dresidual
=
dresidualmat
.
view
(
x
.
shape
)
if
dresidualmat
is
not
None
else
None
dcolscale
=
rest
[
0
]
if
colscale
is
not
None
else
None
dcolscale
=
rest
[
0
]
if
colscale
is
not
None
else
None
return
(
dx0
,
dresidual
,
dgamma
,
dbeta
if
ctx
.
has_beta
else
None
,
dcolscale
,
None
,
None
,
return
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
)
dx0
,
dresidual
,
dgamma
,
dbeta
if
ctx
.
has_beta
else
None
,
dcolscale
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
)
class
DropoutAddLayerNormParallelResidualFn
(
torch
.
autograd
.
Function
):
class
DropoutAddLayerNormParallelResidualFn
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
x0
,
x1
,
residual
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
def
forward
(
residual_in_fp32
=
False
,
prenorm
=
False
,
is_rms_norm
=
False
,
return_dmask
=
False
):
ctx
,
x0
,
x1
,
residual
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
residual_in_fp32
=
False
,
prenorm
=
False
,
is_rms_norm
=
False
,
return_dmask
=
False
,
):
x0
=
maybe_align
(
x0
.
contiguous
(),
16
)
x0
=
maybe_align
(
x0
.
contiguous
(),
16
)
x1
=
maybe_align
(
x1
.
contiguous
(),
16
)
if
x1
is
not
None
else
None
x1
=
maybe_align
(
x1
.
contiguous
(),
16
)
if
x1
is
not
None
else
None
residual
=
maybe_align
(
residual
.
contiguous
(),
16
)
if
residual
is
not
None
else
None
residual
=
maybe_align
(
residual
.
contiguous
(),
16
)
if
residual
is
not
None
else
None
...
@@ -267,9 +553,26 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
...
@@ -267,9 +553,26 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
beta0
=
maybe_align
(
beta0
.
contiguous
(),
16
)
if
beta0
is
not
None
else
None
beta0
=
maybe_align
(
beta0
.
contiguous
(),
16
)
if
beta0
is
not
None
else
None
gamma1
=
maybe_align
(
gamma1
.
contiguous
(),
16
)
if
gamma1
is
not
None
else
None
gamma1
=
maybe_align
(
gamma1
.
contiguous
(),
16
)
if
gamma1
is
not
None
else
None
beta1
=
maybe_align
(
beta1
.
contiguous
(),
16
)
if
beta1
is
not
None
else
None
beta1
=
maybe_align
(
beta1
.
contiguous
(),
16
)
if
beta1
is
not
None
else
None
z0mat
,
z1mat
,
xmat
,
dmask0
,
dmask1
,
mu
,
rsigma
=
_dropout_add_layer_norm_parallel_residual_forward
(
(
x0
,
x1
,
residual
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
z0mat
,
residual_in_fp32
,
is_rms_norm
z1mat
,
xmat
,
dmask0
,
dmask1
,
mu
,
rsigma
,
)
=
_dropout_add_layer_norm_parallel_residual_forward
(
x0
,
x1
,
residual
,
gamma0
,
beta0
,
gamma1
,
beta1
,
dropout_p
,
epsilon
,
residual_in_fp32
,
is_rms_norm
,
)
)
ctx
.
save_for_backward
(
xmat
.
view
(
x0
.
shape
),
dmask0
,
dmask1
,
gamma0
,
gamma1
,
mu
,
rsigma
)
ctx
.
save_for_backward
(
xmat
.
view
(
x0
.
shape
),
dmask0
,
dmask1
,
gamma0
,
gamma1
,
mu
,
rsigma
)
ctx
.
prenorm
=
prenorm
ctx
.
prenorm
=
prenorm
...
@@ -282,13 +585,21 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
...
@@ -282,13 +585,21 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
if
not
return_dmask
:
if
not
return_dmask
:
return
z
if
not
prenorm
else
(
*
z
,
xmat
.
view
(
x0
.
shape
))
return
z
if
not
prenorm
else
(
*
z
,
xmat
.
view
(
x0
.
shape
))
else
:
else
:
dmask0
=
(
dmask0
.
view
(
x0
.
shape
)
if
dropout_p
>
0.
dmask0
=
(
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
.
device
))
dmask0
.
view
(
x0
.
shape
)
dmask1
=
(
dmask1
.
view
(
x0
.
shape
)
if
dropout_p
>
0.
and
x1
is
not
None
if
dropout_p
>
0.0
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
.
device
))
else
torch
.
ones
(
x0
.
shape
,
dtype
=
torch
.
uint8
,
device
=
x0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1)
                if not prenorm
                else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
...
@@ -299,63 +610,170 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
...
@@ -370,6 +788,13 @@ class DropoutAddLayerNorm(torch.nn.Module):
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
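The functional wrappers and the DropoutAddLayerNorm module above all funnel into the same fused autograd function. The following is a minimal usage sketch, assuming the dropout_layer_norm CUDA extension is built and a CUDA device is available; the shapes, dtypes, and dropout probability are illustrative and not taken from this commit.

    import torch
    from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm

    hidden = 1024
    x0 = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x0)

    # Module form: fused dropout(x0) + residual add + LayerNorm in one kernel.
    norm = DropoutAddLayerNorm(
        hidden, p=0.1, residual_in_fp32=True, device="cuda", dtype=torch.float16
    )
    out = norm(x0, residual)

    # Functional form with explicit weight/bias; prenorm=True would additionally
    # return the updated residual stream alongside the normalized output.
    out2 = dropout_add_layer_norm(
        x0, residual, norm.weight, norm.bias, 0.1, norm.eps, residual_in_fp32=True
    )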
flash_attn/ops/rms_norm.py View file @ f1a73d07
...
@@ -4,60 +4,130 @@
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
...
@@ -68,22 +138,37 @@ class RMSNorm(torch.nn.Module):
class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
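As with the LayerNorm variants, a short usage sketch for the RMS-norm wrappers, again assuming the dropout_layer_norm CUDA extension and a CUDA device; sizes and dropout probability are illustrative.

    import torch
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm

    hidden = 1024
    x0 = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x0)

    norm = RMSNorm(hidden, eps=1e-5, device="cuda", dtype=torch.float16)
    y = norm(x0)  # plain RMS norm; note that no bias parameter is registered

    # Fused dropout(x0) + residual add + RMSNorm: passing bias=None and the
    # hard-coded True above (is_rms_norm) selects the RMS path of the shared kernel.
    out = dropout_add_rms_norm(
        x0, residual, norm.weight, None, 0.1, norm.eps, residual_in_fp32=True
    )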
flash_attn/ops/triton/k_activations.py View file @ f1a73d07
...
@@ -11,7 +11,6 @@ from typing import Optional
import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
...
@@ -142,6 +141,7 @@ def gelu_grad(x):
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
...
@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
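The gelu_approx_grad hunk above hard-codes the constants of the tanh approximation. The snippet below is a plain-PyTorch cross-check of that derivative, a sketch independent of the Triton kernels, where 0.79788456 ≈ sqrt(2/π) and 0.1070322243 ≈ 3 · 0.044715 · sqrt(2/π).

    import torch

    def gelu_approx_ref(x):
        # tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
        return 0.5 * x * (1 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    def gelu_approx_grad_ref(x):
        # Same expression as the fused kernel's gradient above.
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        return 0.5 * x * (
            (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
        ) + 0.5 * (1 + tanh_out)

    x = torch.linspace(-3, 3, 101, dtype=torch.float64, requires_grad=True)
    (autograd_grad,) = torch.autograd.grad(gelu_approx_ref(x).sum(), x)
    assert torch.allclose(autograd_grad, gelu_approx_grad_ref(x.detach()), atol=1e-6)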
flash_attn/ops/triton/linear.py View file @ f1a73d07
...
@@ -9,8 +9,14 @@ from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
    gelu,
    gelu_approx,
    gelu_approx_grad,
    gelu_grad,
    squared_relu,
    squared_relu_grad,
)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
...
@@ -28,7 +34,12 @@ def get_configs_io_bound():
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
                            num_stages=num_stages,
                            num_warps=num_warps,
                        )
                    )
...
@@ -43,29 +54,75 @@ def get_configs_io_bound():
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
...
@@ -204,7 +261,7 @@ def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> torch.Tensor:
    """
...
@@ -221,7 +278,7 @@ def triton_linear_act(
    # dtype = torch.get_autocast_gpu_dtype()
    # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
...
@@ -233,12 +290,20 @@ def triton_linear_act(
    weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None
    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"
    M, K = x_reshaped.shape
    N, K = weight.shape
...
@@ -278,35 +343,83 @@ def triton_linear_act(
    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (
            output.reshape(*batch_shape, output.shape[-1]),
            act_input.reshape(*batch_shape, act_input.shape[-1]),
        )


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
...
@@ -395,7 +508,7 @@ def kernel_bwd(
        B += BLOCK_K * stride_bk

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION != "id":
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
        if ACTIVATION == "gelu":
...
@@ -418,7 +531,7 @@ def kernel_bwd(
def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
...
@@ -430,7 +543,7 @@ def triton_dgrad_act(
    :param act_input: an optional tensor to save the activation inputs (for backward)
    :return: result tensor
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
...
@@ -441,10 +554,14 @@ def triton_dgrad_act(
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"
    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
...
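A usage sketch for the fused linear + activation entry point whose signature appears above, compared against an unfused PyTorch reference; it assumes a CUDA device with Triton installed, and the sizes and weight scaling are illustrative.

    import torch
    import torch.nn.functional as F
    from flash_attn.ops.triton.linear import triton_linear_act

    x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    weight = torch.randn(4096, 1024, device="cuda", dtype=torch.float16) / 32
    bias = torch.randn(4096, device="cuda", dtype=torch.float16)

    # Fused matmul + bias + squared-ReLU; save_act_input=True also returns the
    # pre-activation values so a backward pass can recompute the activation gradient.
    out, act_input = triton_linear_act(
        x, weight, bias, activation="squared_relu", save_act_input=True
    )

    # Unfused reference for comparison (squared ReLU is relu(x) ** 2).
    ref_act_input = F.linear(x, weight, bias)
    ref_out = F.relu(ref_act_input) ** 2
    print((out - ref_out).abs().max(), (act_input - ref_act_input).abs().max())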
flash_attn/ops/triton/mlp.py View file @ f1a73d07
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act


class FusedDenseSqreluDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
...
@@ -23,8 +21,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        """
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [
                a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]
            ]
        is_bf16 = x.dtype == torch.bfloat16
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
...
@@ -35,13 +34,18 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if is_bf16:
            act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
            output1 = sqrelu_fwd(act_input)
        else:
            save_act_input = checkpoint_lvl != 2
            result = triton_linear_act(
                x.reshape(batch_dim, n),
                weight1,
                bias1,
                activation="squared_relu",
                save_act_input=save_act_input,
            )
            if save_act_input:
                output1, act_input = result
...
@@ -69,16 +73,21 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        if checkpoint_lvl == 0:
            act_input, output1 = rest
        elif checkpoint_lvl == 1:
            (act_input,) = rest
            output1 = sqrelu_fwd(act_input)
        elif checkpoint_lvl == 2:
            if is_bf16:
                act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
                output1 = sqrelu_fwd(act_input)
            else:
                output1, act_input = triton_linear_act(
                    x.reshape(batch_dim, n),
                    weight1,
                    bias1,
                    activation="squared_relu",
                    save_act_input=True,
                )
        if is_bf16:
...
@@ -92,8 +101,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
        else:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_act_input = triton_dgrad_act(
                grad_output, weight2, activation="squared_relu", act_input=act_input
            )
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input
            )
...
@@ -104,9 +114,17 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
class FusedDenseSqreluDense(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        checkpoint_lvl=0,
        device=None,
        dtype=None,
    ):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
...
@@ -114,7 +132,7 @@ class FusedDenseSqreluDense(nn.Module):
            2: recompute gelu_in and gelu_out in the bwd
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
...
@@ -126,6 +144,6 @@ class FusedDenseSqreluDense(nn.Module):
    def forward(self, x):
        assert x.is_cuda
        return fused_dense_sqrelu_dense_function(
            x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl
        )
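Putting the pieces together, a minimal sketch of how the fused MLP module above might be used, assuming CUDA, Triton, and the fused_dense_lib extension are available; the feature sizes are illustrative.

    import torch
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense

    mlp = FusedDenseSqreluDense(
        in_features=1024,
        hidden_features=4096,
        checkpoint_lvl=1,
        device="cuda",
        dtype=torch.float16,
    )
    x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
    y = mlp(x)          # fused fc1 -> squared-ReLU -> fc2
    y.sum().backward()  # checkpoint_lvl=1 stores the pre-activation and recomputes
                        # the activation output in the backward pass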
flash_attn/utils/benchmark.py View file @ f1a73d07
...
@@ -5,31 +5,43 @@ import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn,
    *inputs,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
...
@@ -37,7 +49,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
...
@@ -46,22 +59,31 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
...
@@ -69,68 +91,142 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            g = torch.randn_like(fn(*inputs, **kwinputs))
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
...
@@ -141,9 +237,10 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
            with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
                out = fn(*inputs, **kwinputs)
            if backward:
                out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
...
@@ -151,14 +248,14 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
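A short sketch of how these helpers are typically driven, benchmarking an arbitrary callable (here a plain matmul) on a CUDA device; the sizes, repeat count, and trace filename are illustrative.

    import torch
    from flash_attn.utils.benchmark import benchmark_all, benchmark_memory, pytorch_profiler

    a = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    b = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    fn = torch.matmul

    # Forward, backward, and combined timings under fp16 autocast.
    benchmark_all(fn, a, b, repeats=30, desc="matmul", amp=True, amp_dtype=torch.float16)

    # Peak CUDA memory of a single call, reported in GB.
    benchmark_memory(fn, a, b, desc="matmul")

    # Kernel-level breakdown plus a Chrome trace for chrome://tracing.
    pytorch_profiler(fn, a, b, backward=True, amp=True, trace_filename="matmul_trace.json")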
flash_attn/utils/distributed.py View file @ f1a73d07
...
@@ -17,10 +17,12 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
...
@@ -28,11 +30,12 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
...
@@ -102,8 +105,9 @@ all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
...
@@ -116,8 +120,9 @@ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
...
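A sketch of the raw collectives above under an already-initialized process group; this assumes a multi-GPU NCCL setup launched with torchrun, and the script name, tensor sizes, and rank-to-device mapping are illustrative.

    # Launch with: torchrun --nproc_per_node=2 example_collectives.py
    import torch
    import torch.distributed as dist
    from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # assumes one process per local GPU
    group = dist.group.WORLD

    x = torch.full((4, 8), float(rank), device="cuda")

    # Async all-gather: returns the output tensor plus a work handle to wait on.
    gathered, handle = all_gather_raw(x, group, async_op=True)
    handle.wait()  # gathered now has shape (4 * world_size, 8)

    # Reduce-scatter the gathered tensor back to per-rank shards of shape (4, 8).
    scattered, _ = reduce_scatter_raw(gathered, group)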
flash_attn/utils/generation.py
View file @
f1a73d07
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from
typing
import
Optional
,
Union
,
Sequence
,
Callable
import
gc
import
gc
import
time
import
time
from
dataclasses
import
dataclass
,
field
from
collections
import
namedtuple
from
collections
import
namedtuple
from
dataclasses
import
dataclass
,
field
from
typing
import
Callable
,
Optional
,
Sequence
,
Union
import
torch
import
torch
from
torch
import
Tensor
from
torch.profiler
import
profile
,
record_function
,
ProfilerActivity
from
einops
import
rearrange
from
einops
import
rearrange
from
torch
import
Tensor
from
torch.profiler
import
ProfilerActivity
,
profile
,
record_function
from
transformers.generation
import
GreedySearchDecoderOnlyOutput
,
SampleDecoderOnlyOutput
from
transformers.generation
import
GreedySearchDecoderOnlyOutput
,
SampleDecoderOnlyOutput
...
@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder
...
@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder
class
InferenceParams
:
class
InferenceParams
:
"""Inference parameters that are passed to the main model in order
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
to efficienly calculate and store the context during inference."""
max_sequence_len
:
int
max_sequence_len
:
int
max_batch_size
:
int
max_batch_size
:
int
sequence_len_offset
:
int
=
0
sequence_len_offset
:
int
=
0
...
@@ -38,11 +36,13 @@ def modify_logits_for_top_p_filtering(logits, top_p):
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, float("-inf"))
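To make the filtering above concrete (a toy example, not part of the commit): with top_p = 0.9, tokens whose cumulative probability, accumulated from the least likely upward, stays at or below 0.1 are masked to -inf. The snippet below just replays the same steps on a tiny tensor:

import torch

logits = torch.tensor([[2.0, 1.0, 0.1, -3.0]])
top_p = 0.9
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
print(logits.masked_fill(indices_to_remove, float("-inf")))
# Only the least likely token (-3.0) is pushed to -inf; the other three survive.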


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
...
@@ -54,7 +54,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
...
@@ -62,17 +62,31 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
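A minimal usage sketch of the function above (the logits are random, just to exercise the two code paths):

import torch
from flash_attn.utils.generation import sample

logits = torch.randn(2, 32000)  # (batch, vocab)
greedy = sample(logits, top_k=1)                                 # argmax short-circuit
sampled = sample(logits, top_k=40, top_p=0.95, temperature=0.8)  # filtered multinomial path
print(greedy.shape, sampled.shape)  # both are (batch,)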
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
...
@@ -92,19 +106,24 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )
    scores = []
    with torch.inference_mode():
        if timing:
...
@@ -123,18 +142,32 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
            if not cg:
                logits = model(
                    rearrange(next_token, "b -> b 1"),
                    position_ids=position_ids,
                    inference_params=inference_params,
                    last_token_only=True,
                ).logits
            else:
                logits = model._decoding_cache.run(
                    rearrange(next_token, "b -> b 1"),
                    position_ids,
                    inference_params.sequence_len_offset,
                )
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits if not cg else logits.clone())
            if (
                teacher_outputs is None
                or teacher_output_len <= inference_params.sequence_len_offset + 1
            ):
                next_token = sample(logits, top_k=top_k, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
...
@@ -148,30 +181,45 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
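For context (not part of the commit): models that mix in `GenerationMixin` expose generation through this thin wrapper around `decode`. A hedged sketch of the call pattern, assuming `model` is such a model and `input_ids` is a (batch, seqlen) LongTensor on the right device:

# `model` and `input_ids` are assumed to exist; only the call pattern is shown.
out = model.generate(
    input_ids,
    max_length=128,
    top_k=40,
    top_p=0.95,
    temperature=0.8,
    return_dict_in_generate=True,
    output_scores=True,
)
print(out.sequences.shape)  # prompt tokens followed by the generated ones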
def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
...
@@ -179,9 +227,13 @@ def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {
        i: (
            torch.empty(k_cache_shape, device=device, dtype=dtype),
            torch.empty(v_cache_shape, device=device, dtype=dtype),
        )
        for i in layers
    }
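The return value above is a plain dict mapping layer index to a (k_cache, v_cache) pair of empty tensors. A small sketch of calling it (sizes are made up; assumes a CUDA device, since the cache feeds the fused inference kernel):

import torch
from flash_attn.utils.generation import allocate_inference_cache

kv_cache = allocate_inference_cache(
    max_batch_size=2, max_seqlen=1024, nheads=16, headdim=64,
    layers=24, device="cuda", dtype=torch.float16,
)
k_cache, v_cache = kv_cache[0]
print(v_cache.shape)  # (max_batch_size, nheads, max_seqlen, headdim), per the code above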


def seqlen_to_seqlen_type(seqlen: int) -> int:
...
@@ -211,49 +263,70 @@ class DecodingCGCache:
@torch.inference_mode()
def update_graph_cache(
    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen_,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
            input_ids, position_ids, seqlen
        )

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
...
@@ -275,8 +348,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
...
@@ -288,8 +365,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            last_token_only=True,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        ...
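The `capture_graph` fragments above follow the standard PyTorch CUDA-graph recipe: warm up on a side stream, capture one decoding step into a `torch.cuda.CUDAGraph`, then replay it with fresh inputs copied into the captured buffers. A stripped-down version of that generic pattern, independent of flash-attn (requires a CUDA device):

import torch

static_x = torch.randn(8, 8, device="cuda")
weight = torch.randn(8, 8, device="cuda")

# Warm up on a side stream so capture doesn't record one-time allocator work.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        static_y = static_x @ weight
torch.cuda.current_stream().wait_stream(s)

# Capture one step; replays reuse the same memory and kernel launches.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x @ weight

new_x = torch.randn(8, 8, device="cuda")
static_x.copy_(new_x)  # update the captured input in place
graph.replay()         # static_y now holds new_x @ weight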
flash_attn/utils/pretrained.py View file @ f1a73d07
...
@@ -3,13 +3,18 @@ from functools import partial
import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None
...
@@ -20,19 +25,23 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
    ...
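Typical usage of the loader above (the model name is only an example, and fetching it requires access to the Hugging Face Hub):

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# Non-fp32 dtypes are first loaded onto the CPU, as the code above shows.
state_dict = state_dict_from_pretrained("gpt2", dtype=torch.float16)
print(len(state_dict), "tensors loaded")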