gaoqiong / flash-attention · Commits

Commit 40e534a7, authored Jul 11, 2024 by Tri Dao
Parent: 116b05f9

    Implement cache_leftpad

Showing 6 changed files with 70 additions and 12 deletions:

    csrc/flash_attn/flash_api.cpp            +21  -0
    csrc/flash_attn/src/block_info.h          +5  -3
    csrc/flash_attn/src/flash.h               +1  -0
    csrc/flash_attn/src/flash_fwd_kernel.h    +6  -4
    flash_attn/flash_attn_interface.py        +7  -2
    tests/test_flash_attn.py                 +30  -3
csrc/flash_attn/flash_api.cpp

@@ -532,6 +532,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
@@ -731,6 +732,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                                   head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
     }
 
+    if (leftpad_k_.has_value()) {
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
+
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -1279,6 +1290,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                 c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                 c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
@@ -1469,6 +1481,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+    if (leftpad_k_.has_value()) {
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
 
     if (rotary_cos_.has_value()) {
         TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
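Both entry points apply the same constraints to the new argument before storing its data pointer in params.leftpad_k. As a quick reference, here is a minimal Python sketch of those constraints, using a hypothetical check_leftpad_k helper that is not part of the library:

    import torch

    def check_leftpad_k(leftpad_k: torch.Tensor, batch_size: int, paged_kv: bool) -> None:
        # Hypothetical helper mirroring the TORCH_CHECKs added above.
        assert not paged_kv, "We don't support Paged KV and leftpad_k running at the same time yet"
        assert leftpad_k.dtype == torch.int32, "leftpad_k must have dtype int32"
        assert leftpad_k.is_cuda, "leftpad_k must be a CUDA tensor (CHECK_DEVICE)"
        assert leftpad_k.is_contiguous(), "leftpad_k must be contiguous"
        assert leftpad_k.shape == (batch_size,), "leftpad_k must have shape (batch_size,)"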
csrc/flash_attn/src/block_info.h

@@ -18,8 +18,9 @@ struct BlockInfo {
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
@@ -30,13 +31,14 @@ struct BlockInfo {
     template <typename index_t>
     __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }
 
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
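The net effect of these changes: each batch element's K/V rows are read starting at leftpad_k rather than 0, and the cached length seen by the kernel shrinks by the same amount. A rough Python sketch of the indexing, simplifying the CUDA code under the assumption that seqused_k is not set and no new keys are appended:

    def k_window(seqlen_k_raw: int, leftpad_k: int) -> tuple[int, int]:
        # seqlen_k_cache is the raw cached length minus the left padding, and
        # k_offset() shifts the base pointer forward by leftpad_k rows, so the
        # kernel effectively attends over rows [leftpad_k, seqlen_k_raw).
        seqlen_k_cache = seqlen_k_raw - leftpad_k
        return leftpad_k, leftpad_k + seqlen_k_cache

    # Example: a cache slot filled up to row 128 whose first 16 rows are left padding
    # exposes rows [16, 128) to attention.
    assert k_window(128, 16) == (16, 128)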
csrc/flash_attn/src/flash.h

@@ -76,6 +76,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
 
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
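For orientation, an illustrative Python mirror of the variable-length fields in Flash_fwd_params (field names taken from the struct above; this is not the C++ definition):

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class FwdSeqlenParamsSketch:
        cu_seqlens_q: Optional[torch.Tensor]  # (b + 1,) starting offset of each query sequence
        cu_seqlens_k: Optional[torch.Tensor]  # (b + 1,) starting offset of each key sequence
        leftpad_k: Optional[torch.Tensor]     # (b,) left padding of each K sequence; None means no padding
        seqused_k: Optional[torch.Tensor]     # (b,) actual length of each K sequence, if provided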
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -690,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
@@ -711,9 +711,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -791,7 +793,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                binfo.actual_seqlen_q - m_block * kBlockM);
         } else {
-            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
             // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
             // We do this by setting the row stride of gCos / gSin to 0.
             Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
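The rotary cos/sin tables are indexed by absolute key position, so the kernel now adds the batch element's left padding to the row offset. A small sketch of that arithmetic, in plain Python with simplified names:

    def rotary_cossin_offset(n_block_max: int, kBlockN: int, leftpad_k: int, rotary_dim: int) -> int:
        # Mirrors row_offset_cossin above: the starting key position of the last K block,
        # shifted by the per-batch left padding, times rotary_dim / 2 table entries per position.
        return ((n_block_max - 1) * kBlockN + leftpad_k) * (rotary_dim // 2)

    # With leftpad_k = 0 this reduces to the previous expression.
    assert rotary_cossin_offset(4, 64, 0, 32) == 3 * 64 * 16
    assert rotary_cossin_offset(4, 64, 16, 32) == (3 * 64 + 16) * 16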
flash_attn/flash_attn_interface.py

@@ -81,7 +81,8 @@ def _flash_attn_varlen_forward(
     softcap,
     alibi_slopes,
     return_softmax,
-    block_table,
+    block_table=None,
+    leftpad_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,6 +94,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        leftpad_k,
         block_table,
         alibi_slopes,
         max_seqlen_q,
@@ -1150,6 +1152,7 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
     block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
@@ -1217,11 +1220,12 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
-        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -1269,6 +1273,7 @@ def flash_attn_with_kvcache(
         rotary_cos,
         rotary_sin,
         cache_batch_idx,
+        cache_leftpad,
         block_table,
         alibi_slopes,
         None,
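A minimal usage sketch of the new keyword on flash_attn_with_kvcache, following the docstring above (shapes, dtypes, and values here are illustrative only):

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, headdim, seqlen_k = 2, 8, 64, 128
    q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
    k_cache = torch.randn(batch, seqlen_k, nheads, headdim, dtype=torch.float16, device="cuda")
    v_cache = torch.randn_like(k_cache)

    cache_seqlens = torch.tensor([128, 100], dtype=torch.int32, device="cuda")  # filled length per batch element
    cache_leftpad = torch.tensor([16, 0], dtype=torch.int32, device="cuda")     # index where each KV cache starts

    # Only keys at positions [cache_leftpad[b], cache_seqlens[b]) take part in attention;
    # passing cache_leftpad=None keeps the previous behaviour (start at 0).
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens,
        cache_leftpad=cache_leftpad,
        causal=True,
    )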
tests/test_flash_attn.py

@@ -182,9 +182,14 @@ def construct_local_mask(
     query_padding_mask=None,
     key_padding_mask=None,
     device=None,
+    key_leftpad=None,
 ):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
     sk = (
         seqlen_k
         if key_padding_mask is None
@@ -219,6 +224,7 @@ def attention_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     """
     Arguments:
@@ -268,6 +274,7 @@ def attention_ref(
             query_padding_mask,
             key_padding_mask,
             q.device,
+            key_leftpad=key_leftpad,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
@@ -306,6 +313,7 @@ def attention_kvpacked_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     return attention_ref(
         q,
@@ -321,6 +329,7 @@ def attention_kvpacked_ref(
         window_size=window_size,
         softcap=softcap,
         reorder_ops=reorder_ops,
+        key_leftpad=key_leftpad,
     )
@@ -1868,9 +1877,11 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("paged_kv_block_size", [None])
+@pytest.mark.parametrize("has_leftpad", [False, True])
+# @pytest.mark.parametrize("has_leftpad", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1898,6 +1909,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    has_leftpad,
     paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
@@ -1916,6 +1928,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if has_batch_idx and paged_kv_block_size is not None:
         pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1961,9 +1975,19 @@ def test_flash_attn_kvcache(
             dtype=torch.int32,
             device=device,
         )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
     if has_batch_idx:
         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
             :batch_size
@@ -2038,6 +2062,7 @@ def test_flash_attn_kvcache(
         rotary_sin=sin,
         cache_seqlens=cache_seqlens,
         cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
         block_table=block_table,
         causal=causal,
         window_size=window_size,
@@ -2066,6 +2091,7 @@ def test_flash_attn_kvcache(
         None,
         causal=causal,
         window_size=window_size,
+        key_leftpad=cache_leftpad,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -2080,6 +2106,7 @@ def test_flash_attn_kvcache(
         window_size=window_size,
         upcast=False,
         reorder_ops=True,
+        key_leftpad=cache_leftpad,
     )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
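The reference mask the test builds can be summarized as: key position j of batch b is valid iff cache_leftpad[b] <= j < cache_seqlens[b] (plus seqlen_new when new keys are appended). A standalone sketch of that mask, assuming the same tensor shapes as the test:

    import torch

    def kv_validity_mask(cache_seqlens: torch.Tensor, cache_leftpad, seqlen_k: int) -> torch.Tensor:
        # (batch, seqlen_k) boolean mask of cache positions that participate in attention.
        arange = torch.arange(seqlen_k, device=cache_seqlens.device).unsqueeze(0)
        mask = arange < cache_seqlens.unsqueeze(1)
        if cache_leftpad is not None:
            mask &= arange >= cache_leftpad.unsqueeze(1)
        return mask

    mask = kv_validity_mask(torch.tensor([5, 4]), torch.tensor([2, 0]), seqlen_k=6)
    # batch 0 attends positions 2..4, batch 1 attends positions 0..3
    assert mask.tolist() == [[False, False, True, True, True, False],
                             [True, True, True, True, False, False]]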