flash-attention · Commit e279bf8e
Authored Oct 03, 2023 by Tri Dao

[Gen] Accept cache_batch_idx to index into the KV cache

Parent: 601b4dc4
Showing 5 changed files with 49 additions and 17 deletions.
csrc/flash_attn/flash_api.cpp           +15  -6
csrc/flash_attn/src/flash.h              +3  -0
csrc/flash_attn/src/flash_fwd_kernel.h   +3  -2
flash_attn/flash_attn_interface.py      +10  -2
tests/test_flash_attn.py                +18  -7
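In short, the commit lets flash_attn_with_kvcache read and update a KV cache whose batch dimension is larger than (and ordered differently from) the current batch. A minimal usage sketch, assuming the flash_attn package built from this commit; all shapes and index values below are made up for illustration:

import torch
from flash_attn import flash_attn_with_kvcache

batch_size, batch_size_cache = 2, 4  # cache holds more rows than the batch
seqlen_q, seqlen_cache, nheads, headdim = 1, 128, 6, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype)
# The KV cache may be preallocated for more sequences than the current batch.
k_cache = torch.randn(batch_size_cache, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch_size,), 32, dtype=torch.int32, device=device)
# Route each query sequence to an arbitrary cache row:
# sequence 0 attends to cache row 3, sequence 1 to cache row 1.
cache_batch_idx = torch.tensor([3, 1], dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    cache_batch_idx=cache_batch_idx,
    causal=True,
)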
csrc/flash_attn/flash_api.cpp

@@ -1037,13 +1037,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
-                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
-                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
+                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
                 c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,

@@ -1084,6 +1085,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int head_size_og = sizes[3];
     const int seqlen_k = kcache.size(1);
     const int num_heads_k = kcache.size(2);
+    const int batch_size_c = kcache.size(0);
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -1102,8 +1104,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     at::Tensor q_padded, kcache_padded, vcache_padded;
     if (head_size_og % 8 != 0) {

@@ -1229,6 +1231,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.rotary_dim = 0;
     }

+    if (cache_batch_idx_.has_value()) {
+        auto cache_batch_idx = cache_batch_idx_.value();
+        CHECK_DEVICE(cache_batch_idx);
+        CHECK_CONTIGUOUS(cache_batch_idx);
+        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
+        params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
+    }
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;

@@ -1248,8 +1257,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    // Only split kernel supports appending to KV cache
-    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
+    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());

     if (head_size_og % 8 != 0) {
         out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
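As the new TORCH_CHECKs above show, the binding expects cache_batch_idx to be int32, contiguous, and on the same device as the cache. A small sketch of preparing a valid index tensor on the Python side (the batch sizes here are hypothetical):

import torch

batch_size, batch_size_cache = 2, 4
# Distinct rows, int32 and contiguous, matching the checks in mha_fwd_kvcache.
cache_batch_idx = (
    torch.randperm(batch_size_cache, device="cuda")[:batch_size]
    .to(torch.int32)
    .contiguous()
)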
csrc/flash_attn/src/flash.h

@@ -95,6 +95,9 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ rotary_cos_ptr;
     void * __restrict__ rotary_sin_ptr;

+    // The indices to index into the KV cache.
+    int * __restrict__ cache_batch_idx;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -668,9 +668,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
     // We move K and V to the last block.
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
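The net effect: the split-KV kernel working on batch element bidb now loads K and V from cache row cache_batch_idx[bidb] rather than row bidb, while Q is still read at bidb. A Python analog of that indirection, for illustration only (not the kernel code):

def resolve_cache_row(bidb: int, cache_batch_idx=None) -> int:
    # Mirrors: bidb_cache = cache_batch_idx == nullptr ? bidb : cache_batch_idx[bidb]
    return bidb if cache_batch_idx is None else int(cache_batch_idx[bidb])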
flash_attn/flash_attn_interface.py

@@ -928,6 +928,7 @@ def flash_attn_with_kvcache(
     rotary_cos=None,
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window

@@ -978,8 +979,8 @@ def flash_attn_with_kvcache(
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.

@@ -988,6 +989,10 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).

@@ -1014,6 +1019,8 @@ def flash_attn_with_kvcache(
             cache_seqlens = torch.full(
                 (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
             )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+    cache_batch_idx = maybe_contiguous(cache_batch_idx)
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q,
         k_cache,

@@ -1023,6 +1030,7 @@ def flash_attn_with_kvcache(
         cache_seqlens,
         rotary_cos,
         rotary_sin,
+        cache_batch_idx,
         None,
         softmax_scale,
         causal,
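The docstring addition pins down the semantics: attention for batch element i runs against cache row cache_batch_idx[i], defaulting to the identity mapping. A plain-PyTorch sketch of that gather (semantics only, not what the kernel does internally):

import torch

def gather_cache(k_cache, v_cache, cache_batch_idx=None):
    # With cache_batch_idx=None this is the identity, matching the default
    # batch indices [0, 1, ..., batch_size - 1].
    if cache_batch_idx is None:
        return k_cache, v_cache
    idx = cache_batch_idx.long()
    return k_cache[idx], v_cache[idx]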
tests/test_flash_attn.py

@@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])

@@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("has_batch_idx", [True])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])

@@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache(
     seqlen_q,
     seqlen_k,
     d,
+    has_batch_idx,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,

@@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache(
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 6
     # rotary_dim must be a multiple of 16, and must be <= d
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16

@@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough

@@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+    else:
+        cache_batch_idx = None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi

@@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = k_cache.clone()
-    v_cache_ref = v_cache.clone()
+    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
+    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:

@@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache(
         cos,
         sin,
         cache_seqlens,
+        cache_batch_idx,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,

@@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     if new_kv:
-        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
-        assert torch.equal(v_cache, v_cache_ref)
+        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
     assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
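The updated assertions gather the cache rows the kernel actually used before comparing against the reference. A condensed sketch of that pattern (an illustrative helper, not part of the test file):

import torch

def check_cache(k_cache, k_cache_ref, cache_batch_idx=None):
    # Gather only the rows the kernel was told to use, then compare.
    k_sel = k_cache if cache_batch_idx is None else k_cache[cache_batch_idx.long()]
    assert torch.allclose(k_sel, k_cache_ref, rtol=1e-3, atol=1e-3)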