change / sglang · Commits · 748f86f3

Commit 748f86f3 (unverified), authored Oct 06, 2025 by Lifu Huang, committed via GitHub on Oct 06, 2025
Parent: 73ea484a

[Bug] Fix incorrect assertion in FA4 and add UT. (#11182)
Showing 2 changed files with 576 additions and 5 deletions:

  sgl-kernel/python/sgl_kernel/flash_attn.py      +1    -4
  sgl-kernel/tests/test_flash_attention_4.py      +575  -1
sgl-kernel/python/sgl_kernel/flash_attn.py

@@ -161,10 +161,7 @@ def flash_attn_with_kvcache(
         k is None and v is None
     ), "FA4 does not support updating KV cache in-place."
     assert (
-        rotary_cos is None
-        and rotary_sin is None
-        and rotary_interleaved is None
-        and rotary_seqlens is None
+        rotary_cos is None and rotary_sin is None and rotary_seqlens is None
     ), "FA4 does not support rotary embedding."
     assert (
         cache_batch_idx is None and cache_leftpad is None
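The dropped clause is what made the assertion incorrect: `rotary_interleaved` is presumably a plain boolean flag (the standard FlashAttention interface defaults it to `False`), so it is never `None`, and the old check rejected calls that did not request rotary embedding at all. A minimal, self-contained sketch of the behavioral difference, using hypothetical caller values rather than code from the repository:

```python
# Hypothetical caller state: rotary embedding is NOT requested, but the
# rotary_interleaved flag still carries its boolean default.
rotary_cos = rotary_sin = rotary_seqlens = None
rotary_interleaved = False

# Old check: a boolean flag is never None, so this fails even on valid calls.
old_ok = (
    rotary_cos is None
    and rotary_sin is None
    and rotary_interleaved is None
    and rotary_seqlens is None
)
# New check: only the objects that actually carry rotary state must be None.
new_ok = rotary_cos is None and rotary_sin is None and rotary_seqlens is None

print(old_ok, new_ok)  # False True -> the old assert fired on a valid call
```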
sgl-kernel/tests/test_flash_attention_4.py

@@ -10,10 +10,11 @@ import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 from utils import is_hopper

 flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
+flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4)


 def unpad_input(hidden_states, attention_mask, unused_mask=None):
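The `partial(..., ver=4)` rebinding above is what pins every call in this test file to the FA4 code path. A small standalone sketch of the same pattern; the `attn` function below is a hypothetical stand-in, only the `ver=4` keyword comes from the diff:

```python
from functools import partial


def attn(q, k, v, ver=3):
    # Hypothetical dispatcher standing in for the sgl_kernel wrappers.
    return f"FA{ver} over {len(q)} queries"


attn_v4 = partial(attn, ver=4)  # every later call now targets the FA4 path
print(attn_v4([0.1, 0.2], None, None))  # -> "FA4 over 2 queries"
```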
@@ -873,5 +874,578 @@ def test_flash_attn_varlen_output(
    ).abs().max().item() + dv_atol

@pytest.mark.skipif(
    is_hopper(),
    reason="skip on hopper",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_learnable_sink", [False, True])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False])
# @pytest.mark.parametrize("has_rotary_seqlens", [False, True])
@pytest.mark.parametrize("has_rotary_seqlens", [False])
# @pytest.mark.parametrize("rotary_interleaved", [False, True])
@pytest.mark.parametrize("rotary_interleaved", [True])
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
@pytest.mark.parametrize("page_size", [None, 128])
# @pytest.mark.parametrize("page_size", [128])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
# @pytest.mark.parametrize("varlen_q", [False, True])
@pytest.mark.parametrize("varlen_q", [False])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize("d", [192])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        # # (1, 128 * 1024),
        # # (16, 128 * 1024),
        # (128, 128),
        # (256, 512),  # To test appending KV with more than 1 block
        # (2048, 3577),  # Enough tile to test persistent scheduler
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    varlen_q,
    has_batch_idx,
    has_leftpad,
    page_size,
    rotary_fraction,
    rotary_interleaved,
    has_rotary_seqlens,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    new_kv,
    has_learnable_sink,
    mha_type,
    dtype,
):
    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if rotary_fraction == 0.0 and has_rotary_seqlens:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    # batch_size = 1
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # nheads = 1
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    dv_vals = [d]
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0]
    attention_chunk_vals = [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        # has_qv = d == 64 and dv >= 256
        has_qv = False
        q = (
            torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
            .to(dtype)
            .to(dtype_ref)
        )
        if has_qv:
            qv = (
                torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref)
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv = None
        if varlen_q:
            query_padding_mask = generate_random_padding_mask(
                seqlen_q, batch_size, device, mode="random"
            )
            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(
                q, query_padding_mask
            )
            output_pad_fn = lambda output_unpad: pad_input(
                output_unpad, indices_q, batch_size, seqlen_q
            )
            qv_unpad = (
                rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None
            )
        else:
            query_padding_mask = None
            q_unpad = q
            qv_unpad = qv
            cu_seqlens_q, max_seqlen_q = None, None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (
            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
        )
        if has_learnable_sink:
            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
        else:
            learnable_sink = None
        seqlen_new = (
            seqlen_q
            if seqlen_new_eq_seqlen_q
            else torch.randint(1, seqlen_q + 1, (1,)).item()
        )
        cu_seqlens_k_new = None
        key_new_padding_mask = None
        if new_kv:
            k = (
                torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref)
                .to(dtype)
                .to(dtype_ref)
            )
            v = (
                torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref)
                .to(dtype)
                .to(dtype_ref)
            )
            if varlen_q:
                # k & v are also varlen
                key_new_padding_mask = generate_random_padding_mask(
                    seqlen_new, batch_size, device, mode="random"
                )
                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(
                    k, key_new_padding_mask
                )
                v_unpad, *rest = unpad_input(v, key_new_padding_mask)
            else:
                k_unpad, v_unpad = k, v
        else:
            k, v, k_unpad, v_unpad = None, None, None, None
        if page_size is None:
            k_cache = (
                torch.randn(
                    batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
            v_cache = (
                torch.randn(
                    batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
            page_table = None
        else:
            (
                k_cache,
                v_cache,
                page_table,
                k_cache_paged,
                v_cache_paged,
                num_blocks,
            ) = _generate_block_kvcache(
                seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref
            )
        cache_seqlens = torch.randint(
            0 if new_kv else 1,
            # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
            (
                (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
                if new_kv
                else (seqlen_k + 1)
            ),
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        if has_leftpad:
            cache_leftpad = torch.cat(
                [
                    (
                        torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                        if cache_seqlens[i].item() > 0
                        else torch.zeros(1, dtype=torch.int32, device=device)
                    )
                    for i in range(batch_size)
                ]
            )
        else:
            cache_leftpad = None
        if has_batch_idx:
            cache_batch_idx = torch.randperm(
                batch_size_cache, dtype=torch.int32, device=device
            )[:batch_size]
        else:
            cache_batch_idx = None
        arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
        cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
        if not new_kv:
            key_padding_mask = arange < cache_seqlens_expanded
        else:
            k_new_seqlens = (
                key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
            )
            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
        if has_leftpad:
            key_padding_mask = torch.logical_and(
                key_padding_mask,
                arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k),
            )
        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
        if rotary_dim > 0:
            angle = (
                torch.rand(
                    seqlen_k if page_size is None else num_blocks * page_size,
                    rotary_dim // 2,
                    device=device,
                )
                * 2
                * math.pi
            )
            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            if causal or local:
                q_ro = apply_rotary_emb(
                    q,
                    cos,
                    sin,
                    seqlen_offsets=rotary_seqlens,
                    interleaved=rotary_interleaved,
                )
            else:
                q_ro = rearrange(
                    apply_rotary_emb(
                        rearrange(q, "b s h d -> b 1 (s h) d"),
                        cos,
                        sin,
                        seqlen_offsets=rotary_seqlens,
                        interleaved=rotary_interleaved,
                    ),
                    "b 1 (s h) d -> b s h d",
                    s=seqlen_q,
                )
            # q_ro = q
            k_ro = apply_rotary_emb(
                k,
                cos,
                sin,
                seqlen_offsets=rotary_seqlens,
                interleaved=rotary_interleaved,
            )
        else:
            cos, sin = None, None
            q_ro, k_ro = q, k
        # k_cache[:, 64:] = -1
        k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
        v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
        if new_kv:
            update_mask = torch.logical_and(
                cache_seqlens_expanded <= arange,
                arange < cache_seqlens_expanded + k_new_seqlens,
            )
            k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
            v_to_update = rearrange(v, "b s ... -> (b s) ...")
            if varlen_q:
                k_to_update = k_to_update[indices_k]
                v_to_update = v_to_update[indices_k]
            k_cache_ref[update_mask] = k_to_update
            v_cache_ref[update_mask] = v_to_update
        k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        out_ref, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            learnable_sink=learnable_sink,
            attention_chunk=attention_chunk,
            key_leftpad=cache_leftpad,
        )
        out_pt, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            learnable_sink=learnable_sink,
            attention_chunk=attention_chunk,
            upcast=False,
            reorder_ops=True,
            key_leftpad=cache_leftpad,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )
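        # Note: out_ref is the reference attention kept in the higher-precision dtype,
        # while out_pt re-runs the same reference without upcasting (and with
        # kernel-like op ordering), so it serves as the low-precision PyTorch baseline
        # that the error bounds further below are measured against.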
        q = q.to(dtype)
        q_unpad = q_unpad.to(dtype) if varlen_q else None
        k_cache = k_cache.to(dtype)
        v_cache = v_cache.to(dtype)
        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
        k = k.to(dtype) if k is not None else None
        v = v.to(dtype) if v is not None else None
        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
        qv = qv.to(dtype) if qv is not None else None
        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
        cos = cos.to(dtype) if cos is not None else None
        sin = sin.to(dtype) if sin is not None else None
        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
        # num_splits_vals = [1, 0]
        num_splits_vals = [1]
        # precompute_metadata_vals = [False, True]
        precompute_metadata_vals = [False]
        for num_splits, precompute_metadata in itertools.product(
            num_splits_vals, precompute_metadata_vals
        ):
            # if precompute_metadata:
            #     scheduler_metadata = get_scheduler_metadata(
            #         batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
            #         cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
            #         cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
            #         max_seqlen_k_new=seqlen_new, page_size=page_size,
            #         causal=causal, window_size=window_size, attention_chunk=attention_chunk,
            #         num_splits=num_splits
            #     )
            # else:
            #     scheduler_metadata = None
            scheduler_metadata = None
            # Repeat to test metadata reuse
            for _ in range(1 if not precompute_metadata else 2):
                if page_size is None:
                    k_cache.copy_(k_cache_saved)
                    v_cache.copy_(v_cache_saved)
                else:
                    k_cache_paged.copy_(k_cache_saved)
                    v_cache_paged.copy_(v_cache_saved)
                # out, lse, *rest = flash_attn_with_kvcache(
                out, lse, *rest = flash_attn_with_kvcache(
                    q if not varlen_q else q_unpad,
                    k_cache if page_size is None else k_cache_paged,
                    v_cache if page_size is None else v_cache_paged,
                    # k if not new_kv or not varlen_q else k_unpad,
                    # v if not new_kv or not varlen_q else v_unpad,
                    # qv=qv if not varlen_q else qv_unpad,
                    # rotary_cos=cos,
                    # rotary_sin=sin,
                    cache_seqlens=cache_seqlens,
                    # cache_batch_idx=cache_batch_idx,
                    # cache_leftpad=cache_leftpad,
                    page_table=page_table,
                    cu_seqlens_q=cu_seqlens_q,
                    # cu_seqlens_k_new=cu_seqlens_k_new,
                    # rotary_seqlens=rotary_seqlens,
                    causal=causal,
                    window_size=window_size,
                    sinks=learnable_sink,
                    # attention_chunk=attention_chunk,
                    # rotary_interleaved=rotary_interleaved,
                    # scheduler_metadata=scheduler_metadata,
                    # num_splits=num_splits,
                    return_softmax_lse=True,
                )
                if varlen_q:
                    out = output_pad_fn(out)
                # out = flash_attn_with_kvcache(
                #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
                # )
                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
                # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
                # m = qk.amax(-1, keepdim=True)
                # s_tmp = torch.exp((qk - m) / math.sqrt(d))
                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
                # probs = torch.softmax(qk, dim=-1)
                print(f"Output max diff: {(out - out_ref).abs().max().item()}")
                print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
                print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
                print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
                # breakpoint()
                # Check that FlashAttention's numerical error is at most twice the numerical error
                # of a Pytorch implementation.
                if new_kv:
                    if page_size is None:
                        k_cache_select = (
                            k_cache.to(dtype_ref)
                            if not has_batch_idx
                            else k_cache.to(dtype_ref)[cache_batch_idx]
                        )
                        v_cache_select = (
                            v_cache.to(dtype_ref)
                            if not has_batch_idx
                            else v_cache.to(dtype_ref)[cache_batch_idx]
                        )
                    else:
                        k_cache_select = rearrange(
                            k_cache_paged.to(dtype_ref)[
                                (page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()
                            ],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                        v_cache_select = rearrange(
                            v_cache_paged.to(dtype_ref)[
                                (page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()
                            ],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
                    if dtype is not torch.float8_e4m3fn:
                        assert torch.equal(v_cache_select, v_cache_ref)
                    else:
                        assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3)
                    # breakpoint()
                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
                    if rotary_dim == 0:
                        assert torch.equal(k_cache_select, k_cache_ref)
                    else:
                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
                        #     breakpoint()
                        if dtype is not torch.float8_e4m3fn:
                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
                        else:
                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1)
                mult = 4 if dtype == torch.float8_e4m3fn else 2
                assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
                assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()


def _generate_block_kvcache(
    seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref
):
    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3
    k_cache_paged = (
        torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref)
        .to(dtype)
        .to(dtype_ref)
    )
    v_cache_paged = (
        torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref)
        .to(dtype)
        .to(dtype_ref)
    )
    page_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        k_cache_paged[page_table.flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[page_table.flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
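
# Reader note: page_table[b, i] holds the physical block index backing logical block i
# of batch entry b; k_cache / v_cache are the contiguous, gathered views of the same
# data used by the reference path, while k_cache_paged / v_cache_paged are what the
# paged kernel reads through page_table.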


if __name__ == "__main__":
    pytest.main([__file__])
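To exercise only the new test, pytest's `-k` expression filter can be used; a usage sketch (standard pytest flags, file path taken from the diff), assuming a CUDA machine with the FA4 kernels available:

```python
import pytest

# Select only the newly added KV-cache test and report tersely.
pytest.main([
    "sgl-kernel/tests/test_flash_attention_4.py",
    "-k", "test_flash_attn_kvcache",
    "-q",
])
```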