gaoqiong / flash-attention · Commits

Commit ce3e7280 (Unverified)
Authored Nov 27, 2023 by Jeremy Reizenstein, committed by GitHub on Nov 27, 2023
Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>

Parent: 23b77c81
Showing 4 changed files with 19 additions and 1 deletion (+19 -1)
csrc/flash_attn/flash_api.cpp        +14 -0
csrc/flash_attn/src/block_info.h      +1 -1
csrc/flash_attn/src/flash.h           +3 -0
flash_attn/flash_attn_interface.py    +1 -0
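
The change threads an optional per-batch key-length tensor, seqused_k, through the varlen forward path: with packed (varlen) inputs, batch element i owns keys cu_seqlens_k[i]..cu_seqlens_k[i+1], and when seqused_k is supplied only the first seqused_k[i] of those keys take part in attention. For orientation, below is a rough, unoptimized PyTorch sketch of that intended semantics; it is not the CUDA kernel, it ignores causal masking, dropout, and windowing, and the helper name and layout assumptions are illustrative only.

import math
import torch

def varlen_attn_reference(q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_k=None):
    # Illustrative reference only. q: (total_q, nheads, d); k, v: (total_k, nheads, d).
    out = torch.empty_like(q)
    batch_size = cu_seqlens_q.numel() - 1
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    for i in range(batch_size):
        qs, qe = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        ks = cu_seqlens_k[i].item()
        # Without seqused_k, attend over this sequence's full key range;
        # with seqused_k, attend over only its first seqused_k[i] keys.
        ke = (ks + seqused_k[i].item()) if seqused_k is not None else cu_seqlens_k[i + 1].item()
        qi = q[qs:qe].transpose(0, 1)   # (nheads, seqlen_q_i, d)
        ki = k[ks:ke].transpose(0, 1)   # (nheads, used_k_i, d)
        vi = v[ks:ke].transpose(0, 1)
        scores = torch.einsum("hqd,hkd->hqk", qi, ki) * softmax_scale
        out[qs:qe] = torch.einsum("hqk,hkd->hqd", scores.softmax(dim=-1), vi).transpose(0, 1)
    return out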
csrc/flash_attn/flash_api.cpp
@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
                       float p_dropout,
@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_k = static_cast<int *>(seqused_k);

     // P = softmax(QK^T)
     params.p_ptr = p_d;
@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
+                     nullptr,
                      nullptr,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }

     at::Tensor q_padded, k_padded, v_padded;
     if (head_size_og % 8 != 0) {
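
Taken together, the checks added above pin down the expected layout of seqused_k: an int32, contiguous CUDA tensor of shape (batch_size,). Below is a minimal sketch of a tensor that would pass them; the values and names are made up, and keeping each entry within that sequence's key count from cu_seqlens_k is assumed to be the caller's responsibility, since the diff adds no bounds check.

import torch

# Hypothetical packed batch of 3 sequences; cu_seqlens_k gives cumulative key offsets.
cu_seqlens_k = torch.tensor([0, 7, 19, 24], dtype=torch.int32, device="cuda")
full_k_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]    # per-sequence key counts: [7, 12, 5]

# Use only a prefix of each sequence's keys.
seqused_k = torch.tensor([5, 12, 3], dtype=torch.int32, device="cuda")

assert seqused_k.dtype == torch.int32                  # "seqused_k must have dtype int32"
assert seqused_k.is_cuda                               # "seqused_k must be on CUDA device"
assert seqused_k.is_contiguous()                       # "seqused_k must be contiguous"
assert seqused_k.shape == (cu_seqlens_k.numel() - 1,)  # CHECK_SHAPE(seqused_k_, batch_size)
assert bool((seqused_k <= full_k_lens).all())          # assumed caller-side invariant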
@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      q_padded, kcache_padded, vcache_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      /*p_ptr=*/nullptr,
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
csrc/flash_attn/src/block_info.h
@@ -19,7 +19,7 @@ struct BlockInfo {
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
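
In scalar terms, the replaced initializer means a caller-supplied length simply overrides the length the kernel would otherwise derive for that batch element. A plain-Python paraphrase of that ternary (not code from the repo; the argument names mirror the fields it reads):

def effective_seqlen_k(seqused_k, bidb, seqlen_k_cache, knew_ptr, seqlen_knew):
    # params.seqused_k != nullptr: trust the caller's per-batch key length.
    if seqused_k is not None:
        return seqused_k[bidb]
    # Otherwise: cached length plus any appended K_new rows, as before this commit.
    return seqlen_k_cache + (0 if knew_ptr is None else seqlen_knew)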
csrc/flash_attn/src/flash.h
@@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;

+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;

     // The K_new and V_new matrices.
flash_attn/flash_attn_interface.py
@@ -83,6 +83,7 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
+        None,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
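
Note that in this commit the Python wrapper still hard-codes None for the new argument, so _flash_attn_varlen_forward (and the public varlen functions built on it) does not yet expose seqused_k; using it from Python would mean calling the compiled varlen_fwd binding directly or threading an extra parameter through this wrapper.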