gaoqiong / flash-attention · Commits

Unverified commit 635f159e, authored Apr 16, 2023 by Tri Dao; committed by GitHub on Apr 16, 2023.
Merge pull request #166 from ksivaman/enable_cuda_graph_capture
Enable CUDA graph capture
Parents: 221a39fd, 45567a25
Changes: 5 changed files, with 39 additions and 39 deletions (+39 / -39).

  csrc/flash_attn/fmha_api.cpp                       +13  -3
  csrc/flash_attn/src/fmha.h                          +2  -0
  csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h    +6  -4
  csrc/flash_attn/src/fmha_fprop_kernel_1xN.h         +6  -0
  flash_attn/flash_attn_interface.py                 +12 -32
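The changes below move the dropout RNG bookkeeping onto the device: the forward kernel writes the Philox seed and offset into a small CUDA tensor that is returned to Python and fed back to the backward pass, so no host-side RNG calls are needed between kernel launches. That is the property CUDA graph capture requires. The sketch below is not part of this commit; it is a minimal, hedged illustration of the capture-and-replay workflow this enables, with `attention_step` as a placeholder for any callable whose forward and backward (like FlashAttention with dropout after this change) do no host-side RNG work between launches.

# Minimal capture/replay sketch (not from this commit); `attention_step` and
# `make_graphed_step` are placeholder names for illustration only.
import torch

def make_graphed_step(attention_step, static_input, warmup_iters=3):
    # Warm up on a side stream so lazy allocations happen outside capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            static_input.grad = None
            attention_step(static_input).sum().backward()
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    static_input.grad = None
    with torch.cuda.graph(graph):              # capture one forward + backward
        static_out = attention_step(static_input)
        static_out.sum().backward()
    return graph, static_out

# To run a step later: copy fresh data into `static_input`, then call graph.replay().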
csrc/flash_attn/fmha_api.cpp

@@ -310,6 +310,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    launch_params.params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
     if (is_dropout) {
         // See Note [Acquire lock when using random generators]

@@ -320,6 +324,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     run_fmha_fwd(launch_params);
     std::vector<at::Tensor> result = {softmax_lse};
+    result.push_back(rng_state);
     if (return_softmax) {result.push_back(s);}
     return result;
 }

@@ -353,7 +358,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const bool zero_tensors,
         const bool is_causal,
         const int num_splits,
-        c10::optional<at::Generator> gen_
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;

@@ -488,11 +494,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
-    if (is_dropout) {
+    if (rng_state.has_value()) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if (is_dropout) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }
     launch(params, stream, /*configure=*/false);
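Net effect of the fmha_api.cpp hunks: mha_fwd allocates a two-element int64 CUDA tensor, points params.rng_state at it so the forward kernel can store the Philox seed (index 0) and offset (index 1), and returns it next to softmax_lse; mha_bwd now optionally receives that tensor and, when present, reuses it instead of drawing fresh Philox state. The snippet below is only an illustrative host-side stand-in for that contract (the real writes and reads happen inside the CUDA kernels; function names here are made up):

# Illustrative stand-in for the seed/offset handoff above; requires a CUDA device.
import torch

rng_state = torch.empty(2, dtype=torch.int64, device="cuda")   # what mha_fwd allocates

def fake_forward_kernel(rng_state, seed=1234, offset=0):
    # Stands in for the forward kernel writing params.rng_state[0] and [1] on the device.
    rng_state[0] = seed
    rng_state[1] = offset

def fake_backward_kernel(rng_state):
    # Stands in for mha_bwd's rng_state.has_value() branch: reuse, do not redraw.
    return int(rng_state[0]), int(rng_state[1])

fake_forward_kernel(rng_state)
assert fake_backward_kernel(rng_state) == (1234, 0)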
csrc/flash_attn/src/fmha.h

@@ -125,6 +125,8 @@ struct FMHA_fprop_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
     bool is_bf16;
     bool is_causal;
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -794,8 +794,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
     // The thread index.
     const int tidx = threadIdx.x;
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1];
+    Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
     if( loop_steps == 1 ) {
         compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);

@@ -827,8 +828,9 @@ inline __device__ void compute_dq_dk_dv_seqparallel(const Params &params) {
     // The thread index.
     const int tidx = threadIdx.x;
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1];
+    Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
     int loop_step_idx = blockIdx.z;
     compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false, /*Seq_parallel=*/true>(params, ph, loop_step_idx);
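Both backward kernels above rebuild a thread's Philox stream from the stored seed and offset as offset + (bidb * params.h + bidh) * 32 + tidx % 32, the same expression the forward kernel uses, so a given (batch, head, lane) sees the identical dropout subsequence in forward and backward. A small host-side illustration of that arithmetic (made-up sizes, purely for clarity):

# Illustration of the per-thread Philox subsequence selection used above.
# Each (batch bidb, head bidh) pair owns 32 consecutive subsequences, one per
# lane, which is also why the host advances the generator by b * h * 32.
def philox_subsequence(offset, bidb, bidh, tidx, num_heads):
    return offset + (bidb * num_heads + bidh) * 32 + tidx % 32

num_heads, base_offset = 16, 0
# Forward and backward compute the identical subsequence for the same thread:
fwd = philox_subsequence(base_offset, bidb=2, bidh=5, tidx=70, num_heads=num_heads)
bwd = philox_subsequence(base_offset, bidb=2, bidh=5, tidx=70, num_heads=num_heads)
assert fwd == bwd == base_offset + (2 * 16 + 5) * 32 + 70 % 32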
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
     const int bidb = blockIdx.x;
     // The block index for the head.
     const int bidh = blockIdx.y;
+    // The block index.
+    const int bidx = gridDim.x * bidh + bidb;
     // The thread index.
     const int tidx = threadIdx.x;

@@ -678,6 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
+    if (bidx == 0 && tidx == 0) {
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }
     Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;
flash_attn/flash_attn_interface.py

@@ -18,19 +18,19 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
         Don't change it unless you know what you're doing.
     """
-    softmax_lse, *rest = flash_attn_cuda.fwd(
+    softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
         q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
         softmax_scale, False, causal, return_softmax, num_splits, generator
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
     S_dmask = rest[0] if return_softmax else None
-    return out, softmax_lse, S_dmask
+    return out, softmax_lse, rng_state, S_dmask


 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                          max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
+                          max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state=None,
                           num_splits=0, generator=None):
     """
     num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
         not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.

@@ -41,7 +41,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
     dout = dout.contiguous()  # CUDA code assumes that dout is contiguous
     _, _, _, softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator,
+        rng_state)
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dq, dk, dv, softmax_d

@@ -52,11 +53,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax,
                 deterministic):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, softmax_lse, S_dmask = _flash_attn_forward(
+        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
             qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
             max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax

@@ -72,18 +71,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dqkv = torch.empty_like(qkv)
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
             ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None, None, None, None

@@ -92,11 +86,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                 softmax_scale, causal, return_softmax, deterministic):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, softmax_lse, S_dmask = _flash_attn_forward(
+        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
             q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k,
             max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax
         )

@@ -112,19 +104,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq = torch.empty_like(q)
         dkv = torch.empty_like(kv)
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dkv, None, None, None, None, None, None, None, None, None

@@ -133,11 +120,9 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                 softmax_scale, causal, return_softmax, deterministic):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, softmax_lse, S_dmask = _flash_attn_forward(
+        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
             q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
         )

@@ -153,17 +138,12 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dk, dv, None, None, None, None, None, None, None, None, None
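After these interface changes, dropout reproducibility in the backward pass no longer relies on snapshotting and restoring the global CUDA RNG state; the rng_state tensor filled by the forward kernel is saved by the autograd Function and handed straight back to the backward kernel. For readers who want that pattern in isolation, here is a toy pure-PyTorch analogue (a hypothetical class, not part of this file): the RNG state travels with the autograd graph as a saved tensor, and the mask is regenerated from it in backward without touching the global RNG.

# Toy analogue of the pattern adopted above (hypothetical class, illustration only).
import torch

class ReproducibleDropout(torch.autograd.Function):
    """Dropout whose backward rebuilds the mask from a saved seed tensor."""

    @staticmethod
    def forward(ctx, x, p):
        seed = torch.randint(0, 2**62, (1,))          # RNG state kept as a plain tensor
        mask = ReproducibleDropout._mask(x.shape, x.device, p, seed)
        ctx.save_for_backward(seed)                    # state travels with the autograd graph
        ctx.p, ctx.shape, ctx.dev = p, x.shape, x.device
        return x * mask / (1 - p)

    @staticmethod
    def backward(ctx, grad_out):
        (seed,) = ctx.saved_tensors                    # same state -> identical mask
        mask = ReproducibleDropout._mask(ctx.shape, ctx.dev, ctx.p, seed)
        return grad_out * mask / (1 - ctx.p), None

    @staticmethod
    def _mask(shape, device, p, seed):
        gen = torch.Generator(device=device).manual_seed(int(seed))
        return torch.rand(shape, generator=gen, device=device) >= p

# usage: y = ReproducibleDropout.apply(x, 0.1)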