Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
518a5f4d
Commit
518a5f4d
authored
Jun 09, 2026
by
hly
Browse files
import aicc-master-dev
parent
c2a1b310
Changes
131
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1542 additions
and
271 deletions
+1542
-271
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim256_prefix_prefill_bf16.cpp
.../src/target/flash_fp8_fwd_hdim256_prefix_prefill_bf16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim256_prefix_prefill_fp16.cpp
.../src/target/flash_fp8_fwd_hdim256_prefix_prefill_fp16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_bf16.cpp
.../flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_bf16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_fp16.cpp
.../flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_fp16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_varlen_fwd_permute_bhsd2bshd_hdim128.cpp
...src/target/flash_varlen_fwd_permute_bhsd2bshd_hdim128.cpp
+13
-20
csrc/flash_attn_hg/src/target/flash_varlen_fwd_permute_bshd2bhsd_hdim128.cpp
...src/target/flash_varlen_fwd_permute_bshd2bhsd_hdim128.cpp
+30
-44
flash_attn/__init__.py
flash_attn/__init__.py
+2
-0
flash_attn/flash_attn_interface.py
flash_attn/flash_attn_interface.py
+237
-117
setup.py
setup.py
+140
-30
tests/prefix_decode_sglang_decode.py
tests/prefix_decode_sglang_decode.py
+457
-0
tests/test_unified_attn.py
tests/test_unified_attn.py
+615
-60
No files found.
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim256_prefix_prefill_bf16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2025, Xin Zhou.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
/// Explicit specialization of the FP8 prefix-prefill forward launcher for
/// the template arguments <BFloat16, 256, 256> (the "hdim256" / bf16 variant).
///
/// When the translation unit is compiled with BUILD_FA_FWD defined, the call
/// is forwarded to run_fp8_flash_fwd_prefix_prefill<BFloat16, 256, 256>.
/// Without BUILD_FA_FWD the body is empty and the function does nothing.
///
/// @param params Forward-pass parameters, forwarded unchanged.
/// @param stream HIP stream handle, forwarded unchanged.
template <>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 256, 256>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<BFloat16, 256, 256>(params, stream);
#endif  // BUILD_FA_FWD
}
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim256_prefix_prefill_fp16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2025, Xin Zhou.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
/// Explicit specialization of the FP8 prefix-prefill forward launcher for
/// the template arguments <Float16, 256, 256> (the "hdim256" / fp16 variant).
///
/// When the translation unit is compiled with BUILD_FA_FWD defined, the call
/// is forwarded to run_fp8_flash_fwd_prefix_prefill<Float16, 256, 256>.
/// Without BUILD_FA_FWD the body is empty and the function does nothing.
///
/// @param params Forward-pass parameters, forwarded unchanged.
/// @param stream HIP stream handle, forwarded unchanged.
template <>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 256, 256>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<Float16, 256, 256>(params, stream);
#endif  // BUILD_FA_FWD
}
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_bf16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
/// Explicit specialization of the FP8 prefix-prefill forward launcher for
/// the template arguments <BFloat16, 192, 128> (the "hdimqk192_hdimv128" /
/// bf16 variant, per this file's name).
///
/// When the translation unit is compiled with BUILD_FA_FWD defined, the call
/// is forwarded to run_fp8_flash_fwd_prefix_prefill<BFloat16, 192, 128>.
/// Without BUILD_FA_FWD the body is empty and the function does nothing.
///
/// @param params Forward-pass parameters, forwarded unchanged.
/// @param stream HIP stream handle, forwarded unchanged.
template <>
void run_fp8_mha_fwd_prefix_prefill_<BFloat16, 192, 128>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<BFloat16, 192, 128>(params, stream);
#endif  // BUILD_FA_FWD
}
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_fp16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
/// Explicit specialization of the FP8 prefix-prefill forward launcher for
/// the template arguments <Float16, 192, 128> (the "hdimqk192_hdimv128" /
/// fp16 variant, per this file's name).
///
/// When the translation unit is compiled with BUILD_FA_FWD defined, the call
/// is forwarded to run_fp8_flash_fwd_prefix_prefill<Float16, 192, 128>.
/// Without BUILD_FA_FWD the body is empty and the function does nothing.
///
/// @param params Forward-pass parameters, forwarded unchanged.
/// @param stream HIP stream handle, forwarded unchanged.
template <>
void run_fp8_mha_fwd_prefix_prefill_<Float16, 192, 128>(Flash_fwd_params &params, hipStream_t stream) {
#if defined(BUILD_FA_FWD)
    run_fp8_flash_fwd_prefix_prefill<Float16, 192, 128>(params, stream);
#endif  // BUILD_FA_FWD
}
csrc/flash_attn_hg/src/target/flash_varlen_fwd_permute_bhsd2bshd_hdim128.cpp
View file @
518a5f4d
#ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h"
template
<
>
...
...
@@ -140,28 +139,22 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<128, 4, 32>(
int32_t
block_offset
=
seqlen_limit
*
kHeadDim
;
int32_t
thread_offset
=
lane_id_col
*
8
;
// 一次读取 4x128 的 Half 到 LDS
#if defined(__gfx936__) || defined(__gfx938__)
{
auto
*
lds_ptr
=
(
__attribute__
((
address_space
(
3
)))
int
*
)(
reinterpret_cast
<
size_t
>
(
lds
)
+
static_cast
<
size_t
>
(
lane_id
*
4
)
*
sizeof
(
float
));
__builtin_hcu_raw_buffer_load_lds
(
read_buffer
,
lds_ptr
,
16
,
(
block_offset
+
thread_offset
)
<<
1
,
/* v_offset */
0
,
/* s_offset */
0
,
/* immediate offset, instruction offset */
0
/* auxiliary data | bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_hcu_raw_buffer_load_lds
(
read_buffer
,
lds
+
lane_id
*
4
,
16
,
(
block_offset
+
thread_offset
)
<<
1
,
/* v_offset */
0
,
/* s_offset */
0
,
/* immediate offset, instruction offset */
0
/* auxiliary data | bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#else
*
(
vec4_fp32
*
)(
lds
+
lane_id
*
4
)
=
*
(
vec4_fp32
*
)(
read_ptr
+
((
block_offset
+
thread_offset
)
>>
1
));
#endif
// 从 LDS 转置后, 64 个线程写 4 行, 每次写 128 个 Half, 对应 fetch * 4 + [0,3] 的 seqlen
vec2_fp32
data0
,
data1
;
inlineasm_fa_ds_read2_b32
(
lds
,
lane_id
,
data0
,
0
,
64
);
inlineasm_fa_ds_read2_b32
(
lds
,
lane_id
+
128
,
data1
,
0
,
64
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
vec2_fp32
data0
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
lane_id
,
0
,
64
,
false
);
vec2_fp32
data1
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
lane_id
+
128
,
0
,
64
,
false
);
write_ptr
[(
min
(
actual_seqlen
-
1
,
fetch
*
4
+
0
)
*
num_heads
*
kHeadDim
+
(
lane_id
<<
1
))
>>
1
]
=
data0
[
0
];
write_ptr
[(
min
(
actual_seqlen
-
1
,
fetch
*
4
+
1
)
*
num_heads
*
kHeadDim
+
(
lane_id
<<
1
))
>>
1
]
=
data0
[
1
];
write_ptr
[(
min
(
actual_seqlen
-
1
,
fetch
*
4
+
2
)
*
num_heads
*
kHeadDim
+
(
lane_id
<<
1
))
>>
1
]
=
data1
[
0
];
...
...
@@ -370,4 +363,4 @@ __global__ void flash_fwd_varlen_permute_bhsd2bshd<256, 1, 32>(
#endif
#endif
\ No newline at end of file
csrc/flash_attn_hg/src/target/flash_varlen_fwd_permute_bshd2bhsd_hdim128.cpp
View file @
518a5f4d
#ifdef BUILD_FA_PERMUTE
#include <hip/hip_runtime.h>
#include "../../include/intrinsic.h"
#include "../flash_fwd_permute_hdim128.h"
...
...
@@ -118,29 +117,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 0>(
// Next, this block reads a 4x128 tile: 16 threads read one row of 128 halves (head_dim is hard-coded to 128 here), and each thread reads 8 halves
int32_t
thread_offset
=
lane_id_row
*
128
+
lane_id_col
*
8
;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
{
auto
*
lds_ptr
=
(
__attribute__
((
address_space
(
3
)))
int
*
)(
reinterpret_cast
<
size_t
>
(
lds
)
+
static_cast
<
size_t
>
(
lane_id
*
4
)
*
sizeof
(
float
));
__builtin_hcu_raw_buffer_load_lds
(
read_buffer
,
lds_ptr
,
16
,
(
block_offset
+
thread_offset
)
<<
1
,
/* v_offset */
0
,
/* s_offset */
0
,
/* immediate offset, instruction offset */
0
/* auxiliary data | bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_hcu_raw_buffer_load_lds
(
read_buffer
,
lds
+
lane_id
*
4
,
16
,
(
block_offset
+
thread_offset
)
<<
1
,
/* v_offset */
0
,
/* s_offset */
0
,
/* immediate offset, instruction offset */
0
/* auxiliary data | bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#else
*
(
vec4_fp32
*
)(
lds
+
lane_id
*
4
)
=
*
(
vec4_fp32
*
)(
read_ptr
+
((
block_offset
+
thread_offset
)
>>
1
));
#endif
// 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32
data0
,
data1
;
inlineasm_fa_ds_read2_b32
(
lds
,
lane_id
,
data0
,
0
,
64
);
inlineasm_fa_ds_read2_b32
(
lds
,
lane_id
+
128
,
data1
,
0
,
64
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
vec2_fp32
data0
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
lane_id
,
0
,
64
,
false
);
vec2_fp32
data1
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
lane_id
+
128
,
0
,
64
,
false
);
write_ptr
[(
fetch
*
head_dim
+
(
lane_id
<<
1
)
+
0
*
cur_seqlen_q
*
head_dim
)
>>
1
]
=
data0
[
0
];
write_ptr
[(
fetch
*
head_dim
+
(
lane_id
<<
1
)
+
1
*
cur_seqlen_q
*
head_dim
)
>>
1
]
=
data0
[
1
];
write_ptr
[(
fetch
*
head_dim
+
(
lane_id
<<
1
)
+
2
*
cur_seqlen_q
*
head_dim
)
>>
1
]
=
data1
[
0
];
...
...
@@ -267,29 +260,23 @@ __global__ void flash_fwd_varlen_permute_bshd2bhsd<128, 4, 32>(
// Next, this block reads a 4x128 tile: 16 threads read one row of 128 halves (head_dim is hard-coded to 128 here), and each thread reads 8 halves
int32_t
thread_offset
=
lane_id_row
*
128
+
lane_id_col
*
8
;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
{
auto
*
lds_ptr
=
(
__attribute__
((
address_space
(
3
)))
int
*
)(
reinterpret_cast
<
size_t
>
(
lds
)
+
static_cast
<
size_t
>
(
lane_id
*
4
)
*
sizeof
(
float
));
__builtin_hcu_raw_buffer_load_lds
(
read_buffer
,
lds_ptr
,
16
,
(
block_offset
+
thread_offset
)
<<
1
,
/* v_offset */
0
,
/* s_offset */
0
,
/* immediate offset, instruction offset */
0
/* auxiliary data | bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
__builtin_hcu_raw_buffer_load_lds
(
read_buffer
,
lds
+
lane_id
*
4
,
16
,
(
block_offset
+
thread_offset
)
<<
1
,
/* v_offset */
0
,
/* s_offset */
0
,
/* immediate offset, instruction offset */
0
/* auxiliary data | bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#else
*
(
vec4_fp32
*
)(
lds
+
lane_id
*
4
)
=
*
(
vec4_fp32
*
)(
read_ptr
+
((
block_offset
+
thread_offset
)
>>
1
));
#endif
// 写到 lds 不需要同步, 因为只有一个 wave
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
vec2_fp32
data0
,
data1
;
inlineasm_fa_ds_read2_b32
(
lds
,
lane_id
,
data0
,
0
,
64
);
inlineasm_fa_ds_read2_b32
(
lds
,
lane_id
+
128
,
data1
,
0
,
64
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
vec2_fp32
data0
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
lane_id
,
0
,
64
,
false
);
vec2_fp32
data1
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
lane_id
+
128
,
0
,
64
,
false
);
write_ptr
[(
seqlen_limit
*
head_dim
+
(
lane_id
<<
1
)
+
0
*
cur_seqlen_q
*
head_dim
)
>>
1
]
=
data0
[
0
];
write_ptr
[(
seqlen_limit
*
head_dim
+
(
lane_id
<<
1
)
+
1
*
cur_seqlen_q
*
head_dim
)
>>
1
]
=
data0
[
1
];
write_ptr
[(
seqlen_limit
*
head_dim
+
(
lane_id
<<
1
)
+
2
*
cur_seqlen_q
*
head_dim
)
>>
1
]
=
data1
[
0
];
...
...
@@ -352,7 +339,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
// Next, this block reads a 4x128 tile: 16 threads read one row of 128 halves (head_dim is hard-coded to 128 here), and each thread reads 8 halves
int32_t
thread_offset
=
lane_id_row
*
128
+
lane_id_col
*
8
;
// block 地址 + thread 地址, << 1 是获取偏移的字节数, 写到 lds 是为了转置一下
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
int
m0_offset
=
reinterpret_cast
<
size_t
>
(
lds
)
+
(
fetch
*
256
<<
2
);
int
offset_v
=
(
block_offset
+
thread_offset
)
<<
1
;
asm
volatile
(
...
...
@@ -377,15 +364,14 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
// 把所有的 buffer_load 指令下发之后, 再从 lds 开始读取
#pragma unroll
for
(
int32_t
fetch
=
0
;
fetch
<
SEQLEN_PER_BLOCK
;
++
fetch
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(%0)
\n
"
::
"B"
(
SEQLEN_PER_BLOCK
-
fetch
-
1
));
__builtin_amdgcn_sched_barrier
(
0
);
#endif
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
inlineasm_fa_ds_read2_b32
(
lds
,
fetch
*
256
+
lane_id
,
registers_buffer
[
fetch
*
2
],
0
,
64
);
inlineasm_fa_ds_read2_b32
(
lds
,
fetch
*
256
+
lane_id
+
128
,
registers_buffer
[
fetch
*
2
+
1
],
0
,
64
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
registers_buffer
[
fetch
*
2
]
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
fetch
*
256
+
lane_id
,
0
,
64
,
false
);
registers_buffer
[
fetch
*
2
+
1
]
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
lds
+
fetch
*
256
+
lane_id
+
128
,
0
,
64
,
false
);
}
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -394,7 +380,7 @@ __global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<128,
for
(
int32_t
fetch
=
0
;
fetch
<
SEQLEN_PER_BLOCK
;
++
fetch
)
{
// 限制边界
int32_t
seqlen_limit
=
min
(
actual_seqlen
-
1
,
fetch
);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
// 计算固定的偏移, 字节数目
int32_t
v_addr
=
(
seqlen_limit
*
head_dim
<<
1
)
+
(
lane_id
<<
2
);
// 循环 seqlen_q 次, 每次间隔 4 x 128 个 half, 需要写 4 次
...
...
@@ -688,4 +674,4 @@ template<>
// Specialization for the template arguments <256, 4, 32>. The body is empty,
// so this kernel performs no work when launched.
// NOTE(review): presumably a placeholder for a configuration that is not
// implemented -- confirm against the launcher that selects it.
__global__ void __launch_bounds__(64, 1) flash_fwd_varlen_permute_bshd2bhsd<256, 4, 32>(
    void *output,
    void *query,
    void *split_sizes,
    int64_t head_stride,
    int32_t num_heads,
    int real_headdim) {}
#endif // end of BUILD_FA_PERMUTE
#endif // end of BUILD_FA_PERMUTE
\ No newline at end of file
flash_attn/__init__.py
View file @
518a5f4d
...
...
@@ -8,6 +8,8 @@ if torch.cuda.is_available():
flash_attn_qkvpacked_func
,
flash_attn_varlen_func
,
hg_flash_attn_varlen_func
,
flash_mla_with_kvcache
,
get_mla_metadata
,
vllm_flash_attn_varlen_func
,
flash_attn_varlen_kvpacked_func
,
flash_attn_varlen_qkvpacked_func
,
...
...
flash_attn/flash_attn_interface.py
View file @
518a5f4d
...
...
@@ -2,7 +2,6 @@
from
typing
import
Optional
,
Union
from
typing
import
List
,
Tuple
import
os
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
...
...
@@ -19,13 +18,19 @@ from flash_attn.utils.sparse_utils import hyperparameter_check, get_block_map_me
DEFAULT_FA_VERSION
=
2
try
:
torch
.
_C
.
_dispatch_find_schema_or_throw
(
"flash_attn2_c_op::varlen_fwd"
,
""
)
_has_flash_attn2_c_varlen_fwd
=
True
except
RuntimeError
:
_has_flash_attn2_c_varlen_fwd
=
False
def maybe_contiguous(x):
    """Return ``x`` with a unit-stride last dimension.

    ``None`` and inputs whose last-dimension stride is already 1 are returned
    unchanged; any other input is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
def round_multiple(x, m):
    """Round ``x`` up to a multiple of ``m``.

    Computes ``ceil(x / m) * m`` using integer floor division, so a value
    that is already a multiple of ``m`` is returned unchanged.
    """
    blocks = (x + m - 1) // m
    return blocks * m
def
_get_block_size_n
(
device
,
head_dim
,
is_dropout
,
is_causal
):
# This should match the block sizes in the CUDA kernel
...
...
@@ -59,7 +64,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
elif
head_dim
<=
512
:
return
64
if
torch
.
__version__
>=
"2.4.0"
:
if
torch
.
__version__
>=
"2.4.0"
and
_has_flash_attn2_c_varlen_fwd
:
_torch_custom_op_wrapper
=
torch
.
library
.
custom_op
_torch_register_fake_wrapper
=
torch
.
library
.
register_fake
else
:
...
...
@@ -199,7 +204,11 @@ def varlen_fwd_fake(
return
out
,
softmax_lse
,
p
,
rng_state
_wrapped_flash_attn_varlen_forward
=
torch
.
ops
.
flash_attn2_c_op
.
varlen_fwd
_wrapped_flash_attn_varlen_forward
=
(
torch
.
ops
.
flash_attn2_c_op
.
varlen_fwd
if
_has_flash_attn2_c_varlen_fwd
else
_flash_attn_varlen_forward
)
def
_flash_attn_backward
(
...
...
@@ -596,7 +605,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softcap
,
alibi_slopes
,
deterministic
,
return_softmax
,
return_softmax
,
bhsd
=
False
):
if
softmax_scale
is
None
:
...
...
@@ -611,7 +620,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
window_size
=
window_size
,
softcap
=
softcap
,
alibi_slopes
=
alibi_slopes
,
return_softmax
=
return_softmax
and
dropout_p
>
0
,
return_softmax
=
return_softmax
and
dropout_p
>
0
,
bhsd
=
bhsd
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
out_padded
,
softmax_lse
,
rng_state
)
...
...
@@ -1922,7 +1931,7 @@ def vllm_flash_attn_varlen_func(
# Version selector
fa_version
:
int
=
DEFAULT_FA_VERSION
,
s_aux
=
None
,
):
):
"""
仅用于vllm prefix cache
dropout_p should be set to 0.0 during evaluation
...
...
@@ -1994,7 +2003,7 @@ def vllm_flash_attn_varlen_func(
else
:
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
.
stride
(
-
1
)
!=
1
else
x
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
dummy_cu_seqlens_k
=
torch
.
empty_like
(
cu_seqlens_q
)
...
...
@@ -2005,7 +2014,7 @@ def vllm_flash_attn_varlen_func(
bs
=
cu_seqlens_q
.
shape
[
0
]
-
1
total_q
=
q
.
shape
[
0
]
# max_seqlen_q*bs==total_q and max_seqlen_q<=4 means mtp
# if mtp, k head must be 1.
# if mtp, k head must be 1.
# todo : support k head >1
is_mtp
=
(
max_seqlen_q
*
bs
==
total_q
and
max_seqlen_q
>
1
and
max_seqlen_q
<
5
)
if
(
max_seqlen_q
==
1
or
is_mtp
)
and
real_window_size
[
0
]
==-
1
:
...
...
@@ -2015,9 +2024,9 @@ def vllm_flash_attn_varlen_func(
else
:
out
=
torch
.
empty_like
(
q
)
flash_attn_cuda
.
paged_attention
(
out
,
q
.
reshape
(
bs
,
max_seqlen_q
,
q
.
shape
[
1
],
q
.
shape
[
-
1
]),
k
,
v
,
softmax_scale
,
block_table
,
seqused_k
,
alibi_slopes
,
kv_cache_dtype
,
q_descale
,
k_descale
,
v_descale
,
max_seqlen_k
,
s_aux
)
seqused_k
,
alibi_slopes
,
kv_cache_dtype
,
q_descale
,
k_descale
,
v_descale
,
max_seqlen_k
,
s_aux
)
return
out
is_938
=
"gfx938"
in
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
is_938
=
(
"gfx938"
in
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
or
"gfx92a"
in
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
)
if
(
not
is_938
)
and
k
.
dtype
==
torch
.
float8_e5m2
and
v
.
dtype
==
torch
.
float8_e5m2
:
assert
q
.
dtype
!=
torch
.
float8_e5m2
,
"UnSupport q.dtype:fp8"
q_descale
=
None
...
...
@@ -2048,7 +2057,7 @@ def vllm_flash_attn_varlen_func(
None
,
s_aux
,
)
else
:
else
:
if
(
k
.
dtype
==
torch
.
float8_e4m3fn
or
k
.
dtype
==
torch
.
float8_e5m2
)
and
q
.
dtype
!=
k
.
dtype
:
if
q_descale
is
not
None
:
q
=
q
/
q_descale
...
...
@@ -2059,7 +2068,7 @@ def vllm_flash_attn_varlen_func(
v
,
out
,
cu_seqlens_q
,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k
if
cu_seqlens_k
is
None
else
cu_seqlens_k
,
seqused_k
,
...
...
@@ -2092,7 +2101,7 @@ def vllm_flash_attn_varlen_func(
v
,
out
,
cu_seqlens_q
,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k
if
cu_seqlens_k
is
None
else
cu_seqlens_k
,
seqused_k
,
...
...
@@ -2334,6 +2343,7 @@ def flash_attn_with_kvcache(
assert
v_cache
.
stride
(
-
1
)
==
1
,
"v_cache must have contiguous last dimension"
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
s_aux
=
maybe_contiguous
(
s_aux
)
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
cache_seqlens
is
not
None
and
isinstance
(
cache_seqlens
,
int
):
...
...
@@ -2646,7 +2656,7 @@ def sparse_attn_varlen_func(
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
@@ -2682,7 +2692,7 @@ def sparse_attn_varlen_func(
"""
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
out
,
softmax_lse
=
flash_attn_cuda
.
varlen_fwd_sparse
(
q
,
...
...
@@ -2723,7 +2733,7 @@ def varlen_fwd_unified(
softmax_scale
=
None
,
causal
=
False
,
softcap
=
0.0
,
window_size
=
(
-
1
,
-
1
),
window_size
=
(
-
1
,
-
1
),
alibi_slopes
=
None
,
use_alibi_sqrt
=
False
,
qq_bias
=
None
,
...
...
@@ -2732,15 +2742,125 @@ def varlen_fwd_unified(
*
,
out
=
None
,
return_softmax_lse
=
False
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
):
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
window_size_left
,
window_size_right
=
window_size
if
k
.
dtype
.
is_floating_point
:
k_dtype_bits
=
torch
.
finfo
(
k
.
dtype
).
bits
else
:
k_dtype_bits
=
torch
.
iinfo
(
k
.
dtype
).
bits
is_mtp
=
(
max_seqlen_q
*
seqused_k
.
size
(
0
)
==
q
.
shape
[
0
]
and
1
<
max_seqlen_q
<
16
)
if
max_seqlen_q
>=
16
:
fp8_dtypes
=
[
torch
.
float8_e4m3fn
]
if
hasattr
(
torch
,
"float8_e4m3fnuz"
):
fp8_dtypes
.
append
(
torch
.
float8_e4m3fnuz
)
if
k_dtype_bits
==
16
and
q
.
shape
[
-
1
]
==
256
and
v
.
shape
[
-
1
]
==
256
:
out
,
softmax_lse
=
flash_attn_cuda
.
varlen_fwd_unified
(
q
,
k
,
v
,
out
,
cu_seqlens_q
,
max_seqlen_q
,
seqused_k
,
max_seqlen_k
,
block_table
,
softmax_scale
,
softcap
,
None
,
# q_descale
None
,
# k_descale
None
,
# v_descale
None
,
# output_scale
causal
,
window_size_left
,
window_size_right
,
alibi_slopes
,
use_alibi_sqrt
,
qq_bias
,
s_aux
,
mm_prefix_range
,
)
return
(
out
,
softmax_lse
)
if
return_softmax_lse
else
out
bshd_prefill
=
_require_hg_varlen_symbol
(
"hg_prefix_prefill_varlen_fwd"
)
fa_output
,
*
rest_extend
=
bshd_prefill
(
q
,
k
,
v
,
out
,
# out_
cu_seqlens_q
,
None
,
# cu_seqlens_k
seqused_k
,
alibi_slopes
,
block_table
,
max_seqlen_q
,
max_seqlen_k
,
0.0
,
# dropout
softmax_scale
,
False
,
# zero_tensors
causal
,
window_size
[
0
],
window_size
[
1
],
softcap
,
return_softmax_lse
,
1
,
None
if
(
k_dtype_bits
==
16
)
else
q_descale
,
None
if
(
k_dtype_bits
==
16
)
else
k_descale
,
None
if
(
k_dtype_bits
==
16
)
else
v_descale
,
s_aux
,
True
,
)
return
(
fa_output
,
rest_extend
[
0
])
if
return_softmax_lse
else
fa_output
bs
=
seqused_k
.
size
(
0
)
total_q
=
q
.
shape
[
0
]
if
max_seqlen_q
==
1
or
is_mtp
:
if
k_dtype_bits
==
16
and
s_aux
is
not
None
:
raise
RuntimeError
(
"b16 prefix decode with attention sink is not supported by unified attention yet"
)
assert
not
use_alibi_sqrt
and
qq_bias
is
None
and
mm_prefix_range
is
None
,
\
f
"Arguments not supported in hg_bshd_decode"
bshd_pa_decode
=
_require_hg_varlen_symbol
(
"hg_prefix_decode_varlen_fwd"
)
result
=
bshd_pa_decode
(
q
,
k
,
v
,
out
,
cu_seqlens_q
,
None
,
seqused_k
,
alibi_slopes
,
block_table
,
max_seqlen_q
,
max_seqlen_k
,
0.0
,
softmax_scale
,
False
,
causal
,
window_size
[
0
],
window_size
[
1
],
softcap
,
return_softmax_lse
,
1
,
None
if
(
k_dtype_bits
==
16
)
else
q_descale
,
None
if
(
k_dtype_bits
==
16
)
else
k_descale
,
None
if
(
k_dtype_bits
==
16
)
else
v_descale
,
s_aux
,
True
,
)
fa_output
=
result
[
0
]
return
(
fa_output
,
result
[
1
])
if
return_softmax_lse
else
fa_output
out
,
softmax_lse
=
flash_attn_cuda
.
varlen_fwd_unified
(
q
,
k
,
...
...
@@ -3830,8 +3950,8 @@ def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64):
sparse_map
=
torch
.
zeros_like
(
pooled_score
,
dtype
=
torch
.
int8
)
sparse_map
.
scatter_
(
-
1
,
lut
,
1
)
return
sparse_map
,
lut
,
topk
class
SparseLinearAttention
(
nn
.
Module
):
def
__init__
(
self
,
head_dim
,
topk
,
feature_map
=
'softmax'
,
use_bf16
=
True
,
use_fp8
=
False
,
tie_feature_map_qk
=
True
):
R
'''
...
...
@@ -3877,7 +3997,7 @@ class SparseLinearAttention(nn.Module):
with
torch
.
no_grad
():
nn
.
init
.
zeros_
(
self
.
proj_l
.
weight
)
nn
.
init
.
zeros_
(
self
.
proj_l
.
bias
)
def
forward
(
self
,
q
,
k
,
v
,
return_sparsity
=
False
):
R
'''
Args:
...
...
@@ -3886,18 +4006,18 @@ class SparseLinearAttention(nn.Module):
v: values of shape (B, L, H, D).
return_sparsity: whether to return the actual sparsity.
'''
B
,
seqlen_q
,
H
,
headdim
=
q
.
shape
if
headdim
==
64
:
block_m
=
64
if
seqlen_q
<=
2048
else
128
elif
headdim
==
128
:
block_m
=
64
if
seqlen_q
<=
2048
else
128
block_k
=
64
block_k
=
64
if
headdim
==
64
:
sparse_map
,
lut
,
real_topk
=
get_block_map
(
q
.
transpose
(
1
,
2
).
contiguous
(),
k
.
transpose
(
1
,
2
).
contiguous
(),
topk_ratio
=
self
.
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
else
:
sparse_map
,
lut
,
real_topk
=
get_block_map_fast
(
q
,
k
,
topk_ratio
=
self
.
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
q
=
q
.
to
(
self
.
dtype
)
k
=
k
.
to
(
self
.
dtype
)
v
=
v
.
to
(
self
.
dtype
)
...
...
@@ -3981,7 +4101,7 @@ def sparse_attn_with_sla(
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
maybe_contiguous
=
lambda
x
:
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
dtype
=
torch
.
bfloat16
if
use_bf16
else
torch
.
float16
...
...
@@ -3994,12 +4114,12 @@ def sparse_attn_with_sla(
block_m
=
64
if
seqlen_q
<=
2048
else
128
elif
headdim
==
128
:
block_m
=
64
if
seqlen_q
<=
2048
else
128
block_k
=
64
block_k
=
64
if
headdim
==
64
:
sparse_map
,
lut
,
real_topk
=
get_block_map
(
q
.
transpose
(
1
,
2
).
contiguous
(),
k
.
transpose
(
1
,
2
).
contiguous
(),
topk_ratio
=
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
else
:
sparse_map
,
lut
,
real_topk
=
get_block_map_fast
(
q
,
k
,
topk_ratio
=
topk
,
BLKQ
=
block_m
,
BLKK
=
block_k
)
q
=
q
.
to
(
dtype
)
k
=
k
.
to
(
dtype
)
v
=
v
.
to
(
dtype
)
...
...
@@ -4045,15 +4165,6 @@ def _require_hg_varlen_symbol(name: str):
return
symbol
def
_apply_hg_kvcache_safe_env
()
->
None
:
# DTK/gfx938 PA launch is sensitive to these knobs. Keep the old HG-safe
# defaults unless the caller explicitly asks for the raw kernel selection.
if
os
.
environ
.
get
(
"HG_KVCACHE_RAW_KERNEL"
)
==
"1"
:
return
os
.
environ
.
setdefault
(
"PA_NO_MLS"
,
"1"
)
os
.
environ
.
setdefault
(
"PA_USE_TILE32X32"
,
"1"
)
def
_validate_hg_paged_kv_contract
(
k_cache
,
v_cache
)
->
None
:
if
k_cache
.
dim
()
!=
4
or
v_cache
.
dim
()
!=
4
:
raise
ValueError
(
"HG paged KV cache expects k and v to both be 4D tensors"
)
...
...
@@ -4066,53 +4177,56 @@ def _validate_hg_paged_kv_contract(k_cache, v_cache) -> None:
"v=[num_blocks, page_block_size, num_heads_k, d_v]"
)
def _normalize_hg_paged_q_scales(q_scale, batch_size, num_heads_q, num_heads_k):
    """Normalize ``q_descale`` to a contiguous ``[batch_size, num_heads_q]`` tensor.

    Accepted inputs:
      * a 1-D tensor with ``batch_size * num_heads_q`` or
        ``batch_size * num_heads_k`` elements, which is viewed as 2-D first;
      * a 2-D tensor of shape ``[batch_size, num_heads_q]``, returned as is
        (made contiguous);
      * a 2-D tensor of shape ``[batch_size, num_heads_k]`` when
        ``num_heads_q`` is a multiple of ``num_heads_k``; each column is then
        repeated ``num_heads_q // num_heads_k`` times along dim 1.

    Raises:
        ValueError: if ``q_scale`` is ``None`` or has any other shape.
    """
    if q_scale is None:
        raise ValueError("q_descale must be provided for HG int8 paged-kvcache path")

    shape_error = (
        "q_descale must have shape [batch_size, num_heads_q] "
        "or [batch_size, num_heads_k] for HG int8 paged-kvcache path"
    )

    q_scale = maybe_contiguous(q_scale)

    if q_scale.dim() == 1:
        # Try the per-query-head layout first, then the per-KV-head layout.
        flat_len = q_scale.numel()
        for heads in (num_heads_q, num_heads_k):
            if flat_len == batch_size * heads:
                q_scale = q_scale.view(batch_size, heads)
                break

    if q_scale.dim() != 2 or q_scale.shape[0] != batch_size:
        raise ValueError(shape_error)

    heads_in_scale = q_scale.shape[1]
    if heads_in_scale == num_heads_q:
        return q_scale.contiguous()
    if heads_in_scale == num_heads_k and num_heads_q % num_heads_k == 0:
        group_size = num_heads_q // num_heads_k
        return q_scale.repeat_interleave(group_size, dim=1).contiguous()
    raise ValueError(shape_error)
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
    is_fp8_kvcache: bool = False,
):
    """Return a ``(None, None)`` pair; every argument is ignored.

    NOTE(review): this looks like a compatibility stub that exists so callers
    can keep a ``get_mla_metadata(...)`` call in place -- confirm the intended
    contract with the callers.
    """
    return (None, None)
def
_expand_hg_paged_kv_scales
(
scale
,
block_table
,
page_block_size
,
num_heads_k
,
name
):
if
scale
is
None
:
raise
ValueError
(
f
"
{
name
}
must be provided for HG int8 paged-kvcache path"
)
scale
=
maybe_contiguous
(
scale
)
batch_size
=
block_table
.
shape
[
0
]
if
scale
.
dim
()
==
1
and
scale
.
numel
()
==
batch_size
*
num_heads_k
:
scale
=
scale
.
view
(
batch_size
,
num_heads_k
)
if
scale
.
dim
()
!=
2
or
scale
.
shape
!=
(
batch_size
,
num_heads_k
):
raise
ValueError
(
f
"
{
name
}
must have shape [batch_size, num_heads_k] for HG int8 paged-kvcache path"
def
flash_mla_with_kvcache
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
Optional
[
torch
.
Tensor
],
num_splits
:
Optional
[
torch
.
Tensor
],
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
use_cuda_graph
:
bool
=
True
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
k_cache
.
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
raise
NotImplementedError
(
"HG MLA dispatch in the main repository supports fp16/bf16 only; "
"fp8/int8 MLA is not supported."
)
expanded
=
torch
.
empty
(
(
int
(
block_table
.
max
().
item
())
+
1
,
page_block_size
,
num_heads_k
),
device
=
scale
.
device
,
dtype
=
scale
.
dtype
,
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
hg_mla
=
_require_hg_varlen_symbol
(
"hg_fwd_kvcache_mla"
)
max_seqlen_k
=
1
if
use_cuda_graph
else
cache_seqlens
.
max
().
item
()
result
=
hg_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
out
,
max_seqlen_k
,
)
for
batch_idx
in
range
(
batch_size
):
block_ids
=
block_table
[
batch_idx
].
to
(
dtype
=
torch
.
long
)
expanded
[
block_ids
]
=
scale
[
batch_idx
].
view
(
1
,
1
,
num_heads_k
).
expand
(
block_ids
.
numel
(),
page_block_size
,
num_heads_k
)
return
expanded
.
contiguous
()
if
len
(
result
)
<
2
:
raise
RuntimeError
(
"hg_fwd_kvcache_mla did not return softmax_lse"
)
return
result
[
0
],
result
[
1
]
def
hg_flash_attn_varlen_func
(
q
,
...
...
@@ -4247,8 +4361,6 @@ def hg_flash_attn_varlen_func(
unsupported
.
append
(
"num_splits"
)
if
fa_version
!=
2
:
unsupported
.
append
(
"fa_version"
)
if
s_aux
is
not
None
:
unsupported
.
append
(
"s_aux"
)
if
custom_mask
is
not
None
:
unsupported
.
append
(
"custom_mask"
)
if
unsupported
:
...
...
@@ -4266,6 +4378,7 @@ def hg_flash_attn_varlen_func(
raise
ValueError
(
"cu_seqlens_q must be provided"
)
q
,
k
,
v
=
[
maybe_contiguous
(
x
)
for
x
in
(
q
,
k
,
v
)]
s_aux
=
maybe_contiguous
(
s_aux
)
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
...
...
@@ -4333,10 +4446,6 @@ def hg_flash_attn_varlen_func(
raise
ValueError
(
"cu_seqlens_k and seqused_k cannot be provided at the same time"
)
if
block_table
is
None
:
raise
ValueError
(
"block_table must be provided when seqused_k is used"
)
if
return_attn_probs
:
raise
NotImplementedError
(
"return_attn_probs is not supported for HG prefix/paged compatibility paths"
)
if
dropout_p
!=
0.0
:
raise
NotImplementedError
(
"dropout_p must be 0.0 for HG prefix/paged compatibility paths"
)
...
...
@@ -4346,6 +4455,33 @@ def hg_flash_attn_varlen_func(
k_dtype_bits
=
torch
.
iinfo
(
k
.
dtype
).
bits
if
max_seqlen_q
>
16
or
(
k_dtype_bits
==
8
and
max_seqlen_q
>
1
):
if
k_dtype_bits
==
16
and
q
.
shape
[
-
1
]
==
256
and
v
.
shape
[
-
1
]
==
256
:
out
,
softmax_lse
=
flash_attn_cuda
.
varlen_fwd_unified
(
q
,
k
,
v
,
out
,
cu_seqlens_q
,
max_seqlen_q
,
seqused_k
,
max_seqlen_k
,
block_table
,
softmax_scale
,
softcap
,
None
,
# q_descale
None
,
# k_descale
None
,
# v_descale
None
,
# output_scale
causal
,
window_size
[
0
],
window_size
[
1
],
alibi_slopes
,
False
,
# use_alibi_sqrt
None
,
# qq_bias
s_aux
,
None
,
# mm_prefix_range
)
return
(
out
,
softmax_lse
)
if
wants_aux
else
out
prefix_prefill
=
_require_hg_varlen_symbol
(
"hg_prefix_prefill_varlen_fwd"
)
result
=
prefix_prefill
(
q
,
...
...
@@ -4366,15 +4502,16 @@ def hg_flash_attn_varlen_func(
window_size
[
0
],
window_size
[
1
],
softcap
,
return_softmax_lse
,
return_softmax_lse
or
return_attn_probs
,
1
,
None
if
k_dtype_bits
==
16
else
q_descale
,
None
if
k_dtype_bits
==
16
else
k_descale
,
None
if
k_dtype_bits
==
16
else
v_descale
,
s_aux
,
is_bf16_output
,
)
fa_output
=
result
[
0
]
return
(
fa_output
,
result
[
1
])
if
return_softmax_lse
else
fa_output
return
(
fa_output
,
result
[
1
])
if
(
return_softmax_lse
or
return_attn_probs
)
else
fa_output
if
k_dtype_bits
==
16
:
prefix_decode
=
_require_hg_varlen_symbol
(
"hg_prefix_decode_varlen_fwd"
)
...
...
@@ -4397,13 +4534,18 @@ def hg_flash_attn_varlen_func(
window_size
[
0
],
window_size
[
1
],
softcap
,
return_softmax_lse
,
return_softmax_lse
or
return_attn_probs
,
1
,
None
,
None
,
None
,
s_aux
,
is_bf16_output
,
)
fa_output
=
result
[
0
]
return
(
fa_output
,
result
[
1
])
if
return_softmax_lse
else
fa_output
return
(
fa_output
,
result
[
1
])
if
(
return_softmax_lse
or
return_attn_probs
)
else
fa_output
if
return_softmax_lse
:
if
return_softmax_lse
or
return_attn_probs
:
raise
NotImplementedError
(
"return_softmax_lse is not supported for the HG paged-kvcache compatibility path"
)
...
...
@@ -4411,28 +4553,6 @@ def hg_flash_attn_varlen_func(
_validate_hg_paged_kv_contract
(
k
,
v
)
if
k
.
shape
[
1
]
!=
128
:
raise
NotImplementedError
(
"HG paged-kvcache path currently requires page_block_size == 128"
)
_apply_hg_kvcache_safe_env
()
q_descale
=
_normalize_hg_paged_q_scales
(
q_descale
,
batch_size
=
block_table
.
shape
[
0
],
num_heads_q
=
q
.
shape
[
1
],
num_heads_k
=
k
.
shape
[
2
],
)
k_descale
=
_expand_hg_paged_kv_scales
(
k_descale
,
block_table
=
block_table
,
page_block_size
=
k
.
shape
[
1
],
num_heads_k
=
k
.
shape
[
2
],
name
=
"k_descale"
,
)
v_descale
=
_expand_hg_paged_kv_scales
(
v_descale
,
block_table
=
block_table
,
page_block_size
=
v
.
shape
[
1
],
num_heads_k
=
v
.
shape
[
2
],
name
=
"v_descale"
,
)
hg_kvcache
=
_require_hg_varlen_symbol
(
"hg_fwd_kvcache_bshd"
)
result
=
hg_kvcache
(
q
.
unsqueeze
(
1
),
...
...
@@ -4442,7 +4562,7 @@ def hg_flash_attn_varlen_func(
None
,
None
,
seqused_k
,
max_seqlen_k
if
max_seqlen_k
>
0
else
int
(
seqused_k
.
max
().
item
())
,
1
,
None
,
None
,
None
,
...
...
@@ -4456,12 +4576,12 @@ def hg_flash_attn_varlen_func(
window_size
[
1
],
softcap
,
False
,
num_splits
,
-
1
,
None
,
None
,
q_descale
,
k_descale
,
v_descale
,
None
if
k_dtype_bits
==
16
else
q_descale
,
None
if
k_dtype_bits
==
16
else
k_descale
,
None
if
k_dtype_bits
==
16
else
v_descale
,
is_bf16_output
,
)
return
result
[
0
].
squeeze
(
1
)
setup.py
100644 → 100755
View file @
518a5f4d
...
...
@@ -52,6 +52,20 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE
FORCE_CXX11_ABI
=
os
.
getenv
(
"FLASH_ATTENTION_FORCE_CXX11_ABI"
,
"FALSE"
)
==
"TRUE"
def cutlass_include_dirs():
    """Return the CUTLASS include directories that exist on this machine.

    Locations are probed in a fixed order: the in-tree ``csrc/cutlass/include``
    directory, ``$CUTLASS_HOME/include`` when that variable is set and
    non-empty, and finally two hard-coded fallback locations. Only existing
    paths are returned, as strings, in probe order.
    """
    search_order = [Path(this_dir) / "csrc" / "cutlass" / "include"]

    env_root = os.getenv("CUTLASS_HOME")
    if env_root:
        search_order.append(Path(env_root) / "include")

    # NOTE(review): the second fallback is a developer-specific home directory;
    # it is harmless elsewhere because non-existent paths are filtered out.
    search_order += [
        Path("/workspace/cutlass_3.2.1/include"),
        Path("/public/home/huangly/数据采集/cutlass_3.2.1/include"),
    ]

    found = []
    for directory in search_order:
        if directory.exists():
            found.append(str(directory))
    return found
def
get_platform
():
"""
Returns the platform name as used in wheel filenames.
...
...
@@ -110,8 +124,18 @@ _HG_EXPLICIT_SOURCES_BY_MODE = {
"src/target/flash_fwd_hdim128_fp16.cpp"
,
"src/target/flash_fwd_hdim128_padding_mask_bf16.cpp"
,
"src/target/flash_fwd_hdim128_padding_mask_fp16.cpp"
,
"src/target/flash_fp8_fwd_hdim128_bf16.cpp"
,
"src/target/flash_fp8_fwd_hdim128_fp16.cpp"
,
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp"
,
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp"
,
"src/target/flash_fwd_hdim128_prefix_prefill_bf16.cpp"
,
"src/target/flash_fwd_hdim128_prefix_prefill_fp16.cpp"
,
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp"
,
"src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp"
,
"src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_bf16.cpp"
,
"src/target/flash_fp8_fwd_hdimqk192_hdimv128_prefix_prefill_fp16.cpp"
,
"src/target/flash_fp8_fwd_hdim256_prefix_prefill_bf16.cpp"
,
"src/target/flash_fp8_fwd_hdim256_prefix_prefill_fp16.cpp"
,
"src/target/flash_fwd_hdim160_bf16.cpp"
,
"src/target/flash_fwd_hdim160_fp16.cpp"
,
"src/target/flash_fwd_hdim192_bf16.cpp"
,
...
...
@@ -262,13 +286,64 @@ def _ninja_shell_join(args) -> str:
return
" "
.
join
(
_ninja_escape
(
shlex
.
quote
(
str
(
x
)))
for
x
in
args
)
def
_resolve_hg_compiler
()
->
str
:
candidates
=
[
os
.
environ
.
get
(
"FLASH_ATTN_HG_COMPILER"
),
"/opt/dtk/bin/aicc"
,
"aicc"
,
]
for
compiler
in
candidates
:
if
compiler
and
shutil
.
which
(
compiler
):
return
compiler
requested
=
[
c
for
c
in
candidates
if
c
]
raise
RuntimeError
(
"error: no usable HG aicc compiler found from: "
+
", "
.
join
(
repr
(
c
)
for
c
in
requested
)
+
". Set FLASH_ATTN_HG_COMPILER to the DTK aicc path."
)
def
_normalize_hg_gfx_archs
(
gfx_version
:
str
):
archs
=
[]
for
item
in
str
(
gfx_version
).
replace
(
","
,
";"
).
split
(
";"
):
item
=
item
.
strip
()
if
not
item
:
continue
archs
.
append
(
item
if
item
.
startswith
(
"gfx"
)
else
f
"gfx
{
item
}
"
)
return
archs
def
_hg_target_define_value
(
archs
):
return
","
.
join
(
arch
[
3
:]
if
arch
.
startswith
(
"gfx"
)
else
arch
for
arch
in
archs
)
# Build the per-architecture device compiler flags for the given arch names.
# Each flag is emitted as a pair: an "-Xarch_<arch>" selector followed by the
# "-mllvm=..." option it applies to. Returns the flat list of flag strings.
def
_hg_arch_device_flags
(
archs
):
flags
=
[]
for
arch
in
archs
:
# gfx936 / gfx938 get the 768-VGPR option.
if
arch
in
(
"gfx936"
,
"gfx938"
):
flags
.
extend
([
f
"-Xarch_
{
arch
}
"
,
"-mllvm=-support-768-vgprs=true"
])
# gfx928 / gfx92a / gfx946 get the 512-VGPR option.
elif
arch
in
(
"gfx928"
,
"gfx92a"
,
"gfx946"
):
flags
.
extend
([
f
"-Xarch_
{
arch
}
"
,
"-mllvm=-support-512-vgprs=true"
])
# NOTE(review): indentation was lost in this view, so it is unclear whether the
# co-issue-vgpr-size flag below belongs to the elif branch only or is added for
# every arch in the loop -- confirm against the real setup.py.
flags
.
extend
([
f
"-Xarch_
{
arch
}
"
,
"-mllvm=-co-issue-vgpr-size=256"
])
return
flags
def
_is_rocm_5_7
()
->
bool
:
version_file
=
"/opt/rocm/.info/version"
try
:
with
open
(
version_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
return
f
.
read
(
100
).
startswith
(
"5.7.0"
)
except
OSError
:
return
False
def
compute_hg_build_descriptor
(
src_dir
,
build_dir
,
mode
=
"all"
,
extra_options_raw
=
"-DGFX_VERSION=938 -Wl,-Bsymbolic"
,
extra_options_raw
=
"-DGFX_VERSION=938
;936
-Wl,-Bsymbolic"
,
):
"""Collect HG sources and flags for Ninja (no compile). Default: mode=all, gfx938."""
"""Collect HG sources and flags for Ninja (no compile). Default: mode=all, gfx938
/gfx936
."""
import
sysconfig
as
_sysconfig
src_dir
=
os
.
path
.
abspath
(
str
(
src_dir
))
...
...
@@ -279,7 +354,7 @@ def compute_hg_build_descriptor(
BUILD_FA_FWD
=
BUILD_FA_BWD
=
BUILD_FA_KVCACHE
=
False
BUILD_FA_PERMUTE
=
BUILD_FLASHMLA
=
False
BUILD_C_INTERFACE
=
False
BUILD_ASM
=
Fals
e
BUILD_ASM
=
Tru
e
FA_DEBUG
=
True
FA_DEBUG_SUM_MAX
=
False
HEADDIM_128_ONLY
=
False
...
...
@@ -355,14 +430,11 @@ def compute_hg_build_descriptor(
EXTRA_HIP_FLAGS
.
append
(
_tok
)
if
GFX_VERSION
is
None
:
GFX_VERSION
=
"938"
GFX_VERSION
=
"938
;936
"
ROCM_PATH
=
os
.
environ
.
get
(
"ROCM_PATH"
,
os
.
environ
.
get
(
"ROCM_HOME"
,
"/opt/rocm"
))
HG_COMPILER
=
_resolve_hg_compiler
()
if
not
shutil
.
which
(
"hipcc"
):
raise
RuntimeError
(
"error: hipcc not found in PATH. Please activate the DTK environment first."
)
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
ROCM_PATH
,
"include"
)):
raise
RuntimeError
(
f
"error:
{
ROCM_PATH
}
/include not found. "
...
...
@@ -400,7 +472,8 @@ def compute_hg_build_descriptor(
"-lc10"
,
]
_gfx_comma
=
GFX_VERSION
.
replace
(
";"
,
","
)
HG_ARCHS
=
_normalize_hg_gfx_archs
(
GFX_VERSION
)
_gfx_comma
=
_hg_target_define_value
(
HG_ARCHS
)
DEFINES
=
[
f
"-DTARGET=
{
_gfx_comma
}
"
,
"-D__HIP_PLATFORM_AMD__=1"
,
...
...
@@ -444,8 +517,10 @@ def compute_hg_build_descriptor(
DEFINES
.
append
(
"-DPA_PAGE_BLOCK_SIZE"
)
if
MLA_PAGE_BLOCK_SIZE
:
DEFINES
.
append
(
"-DMLA_PAGE_BLOCK_SIZE"
)
if
_is_rocm_5_7
():
DEFINES
.
append
(
"-DROCM_5_7"
)
OFFLOAD_FLAGS
=
[
f
"--offload-arch=
gfx
{
_g
}
"
for
_g
in
GFX_VERSION
.
split
(
";"
)
if
_g
]
OFFLOAD_FLAGS
=
[
f
"--offload-arch=
{
_g
}
"
for
_g
in
HG_ARCHS
]
INCLUDE_FLAGS
=
[
f
"-I
{
ROCM_PATH
}
/include"
,
...
...
@@ -457,23 +532,36 @@ def compute_hg_build_descriptor(
INCLUDE_FLAGS
+=
TORCH_INCLUDE_FLAGS
COMMON_FLAGS
=
[
"-fPIC"
,
"-O3"
,
"-std=c++17"
,
"-fPIC"
,
"-ffast-math"
,
"-fno-finite-math-only"
,
"-fno-gpu-rdc"
,
"-mno-fma"
,
]
DTK_DEVICE_FLAGS
=
[
"-DHIP_ENABLE_WARP_SYNC_BUILTINS"
,
"-mllvm"
,
"-
slp-phi-tree-bb-max-size=10000
"
,
"-
disable-machine-sink
"
,
"-mllvm"
,
"-enable-num-vgprs-512=true"
,
"-Rpass-analysis=kernel-resource-usage"
,
"-ftemplate-backtrace-limit=0"
,
"-disable-code-sink"
,
"-mcode-object-version=5"
,
]
if
not
_is_rocm_5_7
():
DTK_DEVICE_FLAGS
+=
[
"-mllvm"
,
"-amdgpu-enable-rewrite-partial-reg-uses=false"
,
"-mllvm"
,
"-allow-gvn-convergent-call=true"
,
"-mllvm"
,
"-disallow-uniform-vmed3-combine=true"
,
"-mllvm"
,
"-hcu-pre-emit-load-store-opt=false"
,
"-mllvm"
,
"-amdgpu-early-inline-all=true"
,
"-mllvm"
,
"-amdgpu-function-calls=false"
,
]
DTK_DEVICE_FLAGS
+=
_hg_arch_device_flags
(
HG_ARCHS
)
if
os
.
environ
.
get
(
"FLASH_ATTN_HG_SAVE_TEMPS"
,
""
)
==
"1"
:
DTK_DEVICE_FLAGS
.
append
(
"--save-temps"
)
...
...
@@ -555,6 +643,7 @@ def compute_hg_build_descriptor(
"obj_dir"
:
obj_dir
,
"sources"
:
_all_sources
,
"objects"
:
objects
,
"compiler"
:
HG_COMPILER
,
"compile_flags"
:
compile_flags
,
"link_flags"
:
link_flags
,
"out_so"
:
out_so
,
...
...
@@ -566,16 +655,17 @@ def run_hg_ninja_build(descriptor: dict) -> None:
"""Write build_hg.ninja and run ninja (parallel via MAX_JOBS)."""
build_dir
=
descriptor
[
"build_dir"
]
ninja_file
=
descriptor
[
"ninja_path"
]
compiler
=
_ninja_shell_join
([
descriptor
[
"compiler"
]])
out_so_ninja
=
_ninja_escape_path
(
descriptor
[
"out_so"
])
lines
=
[
"ninja_required_version = 1.3"
,
""
,
"rule h
ipcc
_compile"
,
" command =
hipcc
-c $in -o $out $FLAGS"
,
"rule h
g
_compile"
,
f
" command =
{
compiler
}
-c $in -o $out $FLAGS"
,
" description = HG compile $in"
,
""
,
"rule h
ipcc
_link"
,
" command =
hipcc
-shared -o $out @$out.rsp $LINK_FLAGS"
,
"rule h
g
_link"
,
f
" command =
{
compiler
}
-shared -o $out @$out.rsp $LINK_FLAGS"
,
" rspfile = $out.rsp"
,
" rspfile_content = $in"
,
" description = HG link $out"
,
...
...
@@ -585,9 +675,9 @@ def run_hg_ninja_build(descriptor: dict) -> None:
""
,
]
for
src
,
obj
in
zip
(
descriptor
[
"sources"
],
descriptor
[
"objects"
]):
lines
.
append
(
f
"build
{
_ninja_escape_path
(
obj
)
}
: h
ipcc
_compile
{
_ninja_escape_path
(
src
)
}
"
)
lines
.
append
(
f
"build
{
_ninja_escape_path
(
obj
)
}
: h
g
_compile
{
_ninja_escape_path
(
src
)
}
"
)
obj_list
=
" "
.
join
(
_ninja_escape_path
(
obj
)
for
obj
in
descriptor
[
"objects"
])
lines
.
append
(
f
"build
{
out_so_ninja
}
: h
ipcc
_link
{
obj_list
}
"
)
lines
.
append
(
f
"build
{
out_so_ninja
}
: h
g
_link
{
obj_list
}
"
)
lines
.
append
(
""
)
os
.
makedirs
(
build_dir
,
exist_ok
=
True
)
...
...
@@ -616,6 +706,7 @@ HG_BUILD_DIR = os.path.join(this_dir, "build", "flash_attn_hg")
HG_SO_BUILD
=
os
.
path
.
join
(
HG_BUILD_DIR
,
"libflash_attention.so"
)
HG_SO_PKG
=
os
.
path
.
join
(
this_dir
,
"flash_attn"
,
"lib"
,
"libflash_attention.so"
)
HG_LIB_DIR
=
os
.
path
.
dirname
(
HG_SO_PKG
)
os
.
environ
[
'PYTORCH_NVCC'
]
=
'aicc'
# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
...
...
@@ -663,7 +754,21 @@ if not SKIP_CUDA_BUILD:
# HAS_HG_DISPATCH / -lflash_attention are applied there if the .so exists.
hg_compile_defs
=
[]
hg_link_args
=
[]
aicc_flags
=
[
"-mcode-object-version=5"
,
"-mllvm=-support-768-vgprs=true"
,
"-mllvm=-disable-machine-sink"
,
"-mllvm=-disable-code-sink"
,
"-mllvm=-amdgpu-enable-rewrite-partial-reg-uses=false"
,
"-mllvm=-allow-gvn-convergent-call=true"
,
"-mllvm=-disallow-uniform-vmed3-combine=true"
,
"-mllvm=-hcu-pre-emit-load-store-opt=false"
,
"-mllvm=-amdgpu-early-inline-all=true"
,
"-mllvm=-amdgpu-function-calls=false"
,
"-fno-finite-math-only"
,
"--gpu-max-threads-per-block=256"
,
"-mllvm=-unroll-threshold=10000"
]
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
...
...
@@ -896,7 +1001,7 @@ if not SKIP_CUDA_BUILD:
"-std=c++17"
,
"-DDCU_ASM"
,
# "-mllvm -not-combine-fma=true",
"-mllvm -slp-phi-tree-bb-max-size=10000"
,
#
"-mllvm -slp-phi-tree-bb-max-size=10000",
# "-mllvm -allow-cse-cross-bb-convergent-call=true",
# "-mllvm -full-vectorize-slp=true",
f
"-DFLASH_ATTENTION_BF16_TYPE=
{
bf16_type
}
"
,
...
...
@@ -936,6 +1041,7 @@ if not SKIP_CUDA_BUILD:
]
+
generator_flag
+
hg_compile_defs
+
aicc_flags
# + cc_flag
),
},
...
...
@@ -944,6 +1050,7 @@ if not SKIP_CUDA_BUILD:
Path
(
this_dir
)
/
"csrc"
/
"flash_attn"
,
Path
(
this_dir
)
/
"csrc"
/
"flash_attn"
/
"src"
,
Path
(
this_dir
)
/
"csrc"
/
"cutlass"
/
"include"
,
],
)
)
...
...
@@ -1051,13 +1158,16 @@ class NinjaBuildExtension(BuildExtension):
if
os
.
path
.
isdir
(
HG_SRC_DIR
):
os
.
makedirs
(
HG_BUILD_DIR
,
exist_ok
=
True
)
_maybe_clean_hg_build_dir
(
HG_BUILD_DIR
)
print
(
"=== Building HG libflash_attention.so (mode=all, gfx938, ninja) ==="
)
try
:
desc
=
compute_hg_build_descriptor
(
HG_SRC_DIR
,
HG_BUILD_DIR
,
mode
=
"all"
,
extra_options_raw
=
"-DGFX_VERSION=938 -Wl,-Bsymbolic"
,
extra_options_raw
=
"-DGFX_VERSION=938;936 -Wl,-Bsymbolic"
,
)
print
(
"=== Building HG libflash_attention.so "
f
"(mode=all, gfx938/gfx936, ninja, compiler=
{
desc
[
'compiler'
]
}
) ==="
)
run_hg_ninja_build
(
desc
)
if
os
.
path
.
isfile
(
HG_SO_BUILD
):
...
...
@@ -1066,11 +1176,11 @@ class NinjaBuildExtension(BuildExtension):
use_hg
=
True
print
(
f
"=== Copied HG .so ->
{
HG_SO_PKG
}
==="
)
else
:
print
(
"WARNING
: HG build completed but output .so is missing
; continuing without HG dispatch
"
)
raise
RuntimeError
(
"Error
: HG build completed but output .so is missing"
)
except
Exception
as
e
:
print
(
f
"WARNING: HG build failed (
{
e
}
), continuing without HG dispatch
"
)
raise
RuntimeError
(
f
"Error: HG build failed (
{
e
}
)
"
)
else
:
print
(
f
"WARNING
: HG source directory not found (
{
HG_SRC_DIR
}
)
, continuing without HG dispatch
"
)
raise
RuntimeError
(
f
"Error
: HG source directory not found (
{
HG_SRC_DIR
}
)"
)
else
:
# FLASH_BUILD_HG=0 should deterministically disable dispatch even if stale artifacts exist.
if
os
.
path
.
isfile
(
HG_SO_PKG
):
...
...
tests/prefix_decode_sglang_decode.py
0 → 100644
View file @
518a5f4d
import
os
import
sys
import
math
import
torch
import
pickle
import
time
import
numpy
import
argparse
import
random
from
datetime
import
datetime
use_cuda_toolkits
=
os
.
path
.
exists
(
"/usr/local/cuda/bin/nvcc"
)
use_rocm_toolkits
=
os
.
path
.
exists
(
"/opt/rocm/llvm/bin/clang"
)
use_dtk_toolkits
=
os
.
path
.
exists
(
"/opt/dtk/bin/aicc"
)
if
(
use_cuda_toolkits
):
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
elif
(
use_rocm_toolkits
or
use_dtk_toolkits
):
try
:
from
flash_attention_interface
import
flash_attn_varlen_func
,
flash_attn_2_cuda
,
flash_attn_with_kvcache
except
ModuleNotFoundError
:
from
flash_attn.flash_attn_interface
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
import
flash_attn_2_cuda
as
flash_attn_cuda
def _require_hg_varlen_symbol(name: str):
    """Fetch attribute ``name`` from the ``flash_attn_cuda`` extension module.

    Raises:
        RuntimeError: if the attribute is missing or is ``None``.
    """
    resolved = getattr(flash_attn_cuda, name, None)
    if resolved is not None:
        return resolved
    raise RuntimeError(
        f"{name} is unavailable in this build. Rebuild flash_attn with HAS_HG_DISPATCH enabled."
    )
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, do_assert=True, cos_threshold=1e-5) -> None:
    """Print error metrics between two equally-shaped tensors and optionally assert.

    Both tensors are promoted to float64. The printed line contains:
    ``cos_diff`` = 1 - 2*sum(x*y) / max(sum(x*x + y*y), 1e-12), the RMSE, the
    largest absolute difference, and the mean and max of ``|x / y|``.

    Args:
        x, y: tensors to compare; their shapes must match.
        name: label used in the printed line and the shape-mismatch message.
        do_assert: when truthy, assert that ``cos_diff < cos_threshold``.
        cos_threshold: upper bound for ``cos_diff`` used by that assertion.

    Raises:
        AssertionError: on a shape mismatch, or when ``do_assert`` is set and
            ``cos_diff`` is not below ``cos_threshold``.
    """
    assert x.shape == y.shape, "for {}, x and y must have the same shape".format(name)

    lhs = x.double()
    rhs = y.double()
    delta = lhs - rhs

    RMSE = (delta * delta).mean().sqrt().item()
    energy = max((lhs * lhs + rhs * rhs).sum().item(), 1e-12)
    cos_diff = 1 - 2 * (lhs * rhs).sum().item() / energy
    amax_diff = delta.abs().max().item()

    # NOTE(review): these two are statistics of the elementwise ratio |x / y|,
    # not of a relative error; zeros in y yield inf/nan here.
    ratio = (lhs / rhs).abs()
    rel_diff_mean = ratio.mean().item()
    rel_diff_max = ratio.max().item()

    template = "name:{} cos_diff={:.12f}, RMSE=\x1b[35m{:.12f}\x1b[0m, amax_diff=\x1b[35m{:.12f}\x1b[0m, REL=\x1b[35m{:.12f}\x1b[0m, rel_max=\x1b[35m{:.12f}\x1b[0m"
    print(template.format(name, cos_diff, RMSE, amax_diff, rel_diff_mean, rel_diff_max))

    if do_assert:
        assert cos_diff < cos_threshold
def
scaled_dot_product_attention
(
__query
,
__key
,
__value
,
h_q
,
h_kv
,
is_causal
=
False
,
USE_CPU
=
False
,
return_max_sum
=
False
,
original_seqlen_kv
=
0
,
split_slice
=
0
,
is_bshd
=
False
,
window_size
=
(
-
1
,
-
1
)):
__query
=
__query
.
transpose
(
0
,
1
).
contiguous
()
__key
=
__key
.
transpose
(
0
,
1
).
contiguous
()
__value
=
__value
.
transpose
(
0
,
1
).
contiguous
()
# 判断是否使用 CPU 计算 golden, 避免 blas 的影响
original_device
=
__query
.
device
original_dtype
=
__query
.
dtype
if
(
USE_CPU
):
__query
=
__query
.
cpu
()
__key
=
__key
.
cpu
()
__value
=
__value
.
cpu
()
# print("scaled_dot_product_attention: ", query.shape, key.shape, value.shape)
__query
=
__query
.
float
()
__key
=
__key
.
float
()
__value
=
__value
.
float
()
# 如果按照官方的方法返回
if
(
not
return_max_sum
):
__key
=
__key
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
__value
=
__value
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
attn_weight
=
__query
@
__key
.
transpose
(
-
2
,
-
1
)
/
math
.
sqrt
(
__query
.
size
(
-
1
))
# MTP > 1, causal/local mask applied
if
(
window_size
!=
(
-
1
,
-
1
)):
s_q
=
__query
.
shape
[
-
2
]
s_k
=
__key
.
shape
[
-
2
]
left
,
right
=
window_size
if
left
<
0
:
left
=
s_k
if
right
<
0
:
right
=
s_k
row_idx
=
torch
.
arange
(
s_q
,
dtype
=
torch
.
int32
,
device
=
attn_weight
.
device
)[:,
None
]
col_idx
=
torch
.
arange
(
s_k
,
dtype
=
torch
.
int32
,
device
=
attn_weight
.
device
)[
None
,
:]
col_idx_limit_left
=
row_idx
+
s_k
-
s_q
-
left
col_idx_limit_right
=
row_idx
+
s_k
-
s_q
+
right
temp_mask
=
(
col_idx
>=
col_idx_limit_left
)
&
(
col_idx
<=
col_idx_limit_right
)
attn_weight
=
attn_weight
.
masked_fill
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
elif
(
is_causal
):
s_q
=
__query
.
shape
[
-
2
]
s_k
=
__key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
__query
.
dtype
,
device
=
attn_weight
.
device
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
,
device
=
attn_weight
.
device
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
__query
.
dtype
)
attn_weight
+=
attn_bias
# some codes for debug
scores_max
=
attn_weight
.
to
(
torch
.
float32
).
max
(
-
1
)[
0
]
scores_sum
=
torch
.
exp
(
attn_weight
.
to
(
torch
.
float32
)
-
scores_max
.
unsqueeze
(
-
1
)).
sum
(
dim
=-
1
)
# original codes
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
output
=
attn_weight
@
__value
output
=
output
.
transpose
(
0
,
1
).
contiguous
()
return
output
.
to
(
original_device
).
to
(
original_dtype
),
lse
.
to
(
original_device
),
scores_max
.
to
(
original_device
),
scores_sum
.
to
(
original_device
)
# Seed the Python, NumPy and PyTorch random number generators with `seed`
# (default 0), set the cudnn deterministic/benchmark flags, and limit OpenMP and
# PyTorch to a single thread.
# NOTE(review): indentation was lost in this view, so it is unclear which of the
# statements after the CUDA availability check are guarded by it -- confirm
# against the real file.
def
set_random_seed
(
seed
=
0
):
random
.
seed
(
seed
)
# Seed Python's built-in random module
numpy
.
random
.
seed
(
seed
)
# Seed NumPy's random generator
torch
.
manual_seed
(
seed
)
# Seed PyTorch's random generator
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
seed
)
# Seed the random generators of all GPUs
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
benchmark
=
False
os
.
environ
[
'OMP_NUM_THREADS'
]
=
'1'
# Set the OpenMP thread count
torch
.
set_num_threads
(
1
)
# Set the PyTorch thread count
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Process some integers.'
)
parser
.
add_argument
(
'--load'
,
default
=
False
,
action
=
'store_true'
,
help
=
'load path'
)
parser
.
add_argument
(
'--trace'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether dump perf traces'
)
parser
.
add_argument
(
'--bf16'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether use bfloat16 as main dtype'
)
parser
.
add_argument
(
'--fp8'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether use fp8_e4m3 inputs for HG decode'
)
parser
.
add_argument
(
'--pressure'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether do pressure test'
)
parser
.
add_argument
(
'--cpu'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether compute golden via cpu'
)
parser
.
add_argument
(
'--pad'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether make query uncontiguous to simulate vllm behaviors'
)
parser
.
add_argument
(
'--iterations'
,
type
=
int
,
default
=
100
,
help
=
'pressure test times'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
128
,
help
=
'page block_size'
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
1
,
help
=
'batch size for generated inputs'
)
parser
.
add_argument
(
'--seq-q'
,
type
=
int
,
default
=
4
,
help
=
'query length per batch for generated inputs'
)
parser
.
add_argument
(
'--seq-k'
,
type
=
int
,
default
=
2048
,
help
=
'kv length per batch for generated inputs'
)
parser
.
add_argument
(
'--num-heads'
,
type
=
int
,
default
=
24
,
help
=
'number of query heads for generated inputs'
)
parser
.
add_argument
(
'--num-heads-kv'
,
type
=
int
,
default
=
2
,
help
=
'number of kv heads for generated inputs'
)
parser
.
add_argument
(
'--head-dim-qk'
,
type
=
int
,
default
=
128
,
help
=
'query/key head dimension'
)
parser
.
add_argument
(
'--head-dim-v'
,
type
=
int
,
default
=
128
,
help
=
'value head dimension'
)
parser
.
add_argument
(
'--no-causal'
,
dest
=
'causal'
,
default
=
True
,
action
=
'store_false'
,
help
=
'disable causal mask for generated inputs'
)
parser
.
add_argument
(
'--window-left'
,
type
=
int
,
default
=-
1
,
help
=
'left sliding window size'
)
parser
.
add_argument
(
'--window-right'
,
type
=
int
,
default
=-
1
,
help
=
'right sliding window size'
)
parser
.
add_argument
(
'--seed'
,
default
=
False
,
action
=
'store_true'
,
help
=
'whether do pressure test'
)
args
=
parser
.
parse_args
()
if
(
args
.
seed
):
set_random_seed
(
212
)
# 从文件加载输入
if
(
args
.
load
):
nvidia_packet
=
torch
.
load
(
"./demo.pt"
)
query
,
key
,
value
,
cu_seqlens_q
,
max_seqlen_q
,
cache_seqlens
,
max_seqlen_k
,
softmax_scale
,
causal
,
window_size
,
alibi_slopes
,
page_table
,
softcap
,
fa_version
,
q_descale
,
k_descale
,
v_descale
=
nvidia_packet
[
"inputs"
]
vllm_golden
=
nvidia_packet
[
"outputs"
]
# 解析出必要的参数
batch_size
=
page_table
.
shape
[
0
]
assert
batch_size
==
cu_seqlens_q
.
shape
[
0
]
-
1
,
"check batch size"
page_block_size
=
key
.
shape
[
1
]
num_heads_kv
=
key
.
shape
[
2
]
num_heads
=
query
.
shape
[
1
]
head_dim_qk
=
query
.
shape
[
2
]
head_dim_v
=
key
.
shape
[
3
]
infer_dtype
=
query
.
dtype
else
:
# 随机生成 seqkv
batch_size
=
args
.
batch_size
# 得到 Q 的长度
seqlen_q
=
[
args
.
seq_q
for
i
in
range
(
batch_size
)]
seqlen_q_sum
=
sum
(
seqlen_q
)
max_seqlen_q
=
max
(
seqlen_q
)
cu_seqlens_q
=
numpy
.
array
([
0
]
+
numpy
.
cumsum
(
seqlen_q
).
tolist
()).
astype
(
"int32"
)
cu_seqlens_q
=
torch
.
from_numpy
(
cu_seqlens_q
)
# 得到 KV 的长度
cache_seqlens
=
[
args
.
seq_k
for
i
in
range
(
batch_size
)]
# 指定分页块的大小, nvidia 64, ours 128
page_block_size
=
16
if
(
use_cuda_toolkits
)
else
args
.
block_size
# 根据分页块大小计算实际需要的页表的大小
max_seqlen_k
=
max
(
cache_seqlens
)
seqlen_kv_real_required_page
=
[
math
.
ceil
(
it
/
page_block_size
)
for
it
in
cache_seqlens
]
seqlen_kv_real_required_page_sum
=
sum
(
seqlen_kv_real_required_page
)
# 默认按照最大 seqlenkv 的来分配
seqlen_kv_max_required_page
=
math
.
ceil
(
max_seqlen_k
/
page_block_size
)
seqlen_kv_max_required_page_total
=
batch_size
*
seqlen_kv_max_required_page
# 打乱页表
shuffle
=
True
if
(
shuffle
):
block_random
=
torch
.
randperm
(
seqlen_kv_max_required_page_total
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
else
:
block_random
=
torch
.
arange
(
seqlen_kv_max_required_page_total
,
dtype
=
torch
.
int32
)
page_table
=
[]
seq_block_incre
=
0
for
i
in
range
(
batch_size
):
blocks_pad
=
[
0
]
*
seqlen_kv_max_required_page
if
(
shuffle
):
blocks_pad
[:
seqlen_kv_real_required_page
[
i
]]
=
block_random
[
seq_block_incre
:
seq_block_incre
+
seqlen_kv_real_required_page
[
i
]].
cpu
().
tolist
()
seq_block_incre
+=
seqlen_kv_real_required_page
[
i
]
else
:
blocks_pad
=
block_random
[
seq_block_incre
:
seq_block_incre
+
seqlen_kv_max_required_page
].
cpu
().
tolist
()
seq_block_incre
+=
seqlen_kv_max_required_page
page_table
.
append
(
torch
.
IntTensor
(
blocks_pad
))
page_table
=
torch
.
stack
(
page_table
).
contiguous
().
to
(
"cuda"
)
# 创建基本参数
head_dim_qk
=
args
.
head_dim_qk
head_dim_v
=
args
.
head_dim_v
num_heads
=
args
.
num_heads
num_heads_kv
=
args
.
num_heads_kv
infer_dtype
=
torch
.
float16
# deepseek 默认使用 bfloat16 推理
if
(
args
.
bf16
):
infer_dtype
=
torch
.
bfloat16
# 除非命令行指定用 fp16, 不受 args.dtype 影响
softmax_scale
=
1.0
/
math
.
sqrt
(
head_dim_qk
)
causal
=
args
.
causal
window_size
=
(
args
.
window_left
,
args
.
window_right
)
alibi_slopes
=
None
softcap
=
0.0
fa_version
=
2
q_descale
=
torch
.
ones
((
batch_size
,
num_heads
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
k_descale
=
torch
.
ones
((
batch_size
,
num_heads_kv
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
v_descale
=
torch
.
ones
((
batch_size
,
num_heads_kv
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
# 创建输入张量
if
(
args
.
pad
):
query_origin_tensor
=
torch
.
randn
((
seqlen_q_sum
,
num_heads
+
16
,
head_dim_qk
),
dtype
=
infer_dtype
,
device
=
"cuda"
)
q
=
query_origin_tensor
[:,
:
num_heads
]
else
:
q
=
torch
.
randn
((
seqlen_q_sum
,
num_heads
,
head_dim_qk
),
dtype
=
infer_dtype
,
device
=
"cuda"
)
k_cache
=
torch
.
randn
((
seqlen_kv_max_required_page_total
,
page_block_size
,
num_heads_kv
,
head_dim_qk
),
device
=
"cuda"
,
dtype
=
infer_dtype
)
v_cache
=
torch
.
randn
((
seqlen_kv_max_required_page_total
,
page_block_size
,
num_heads_kv
,
head_dim_v
),
device
=
"cuda"
,
dtype
=
infer_dtype
)
vllm_golden
=
None
cu_seqlens_q
=
cu_seqlens_q
.
to
(
q
.
device
)
cache_seqlens
=
torch
.
from_numpy
(
numpy
.
array
(
cache_seqlens
).
astype
(
"int32"
)).
to
(
q
.
device
)
q_ref
=
q
k_cache_ref
=
k_cache
v_cache_ref
=
v_cache
if
args
.
fp8
:
if
not
hasattr
(
torch
,
"float8_e4m3fn"
):
raise
RuntimeError
(
"This PyTorch build does not support torch.float8_e4m3fn"
)
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k_cache
=
k_cache
.
to
(
torch
.
float8_e4m3fn
)
v_cache
=
v_cache
.
to
(
torch
.
float8_e4m3fn
)
q_ref
=
q
.
to
(
infer_dtype
)
k_cache_ref
=
k_cache
.
to
(
infer_dtype
)
v_cache_ref
=
v_cache
.
to
(
infer_dtype
)
# 展示一下输入数据
print
(
"--------------------------------------------------------------------------------------------"
)
print
(
"q: "
,
q
.
shape
,
q
.
dtype
,
q
.
is_contiguous
(),
q
.
stride
())
print
(
"k_cache: "
,
k_cache
.
shape
,
k_cache
.
dtype
,
k_cache
.
is_contiguous
(),
k_cache
.
stride
())
print
(
"v_cache: "
,
v_cache
.
shape
,
v_cache
.
dtype
,
v_cache
.
is_contiguous
(),
v_cache
.
stride
())
print
(
"cu_seqlens_q: "
,
cu_seqlens_q
.
shape
,
cu_seqlens_q
.
dtype
,
cu_seqlens_q
.
is_contiguous
())
print
(
"cu_seqlens_q: "
,
cu_seqlens_q
)
print
(
"max_seqlen_q: "
,
max_seqlen_q
)
print
(
"cache_seqlens: "
,
cache_seqlens
)
print
(
"max_seqlen_k: "
,
max_seqlen_k
)
print
(
"softmax_scale: "
,
softmax_scale
)
print
(
"causal: "
,
causal
)
print
(
"window_size: "
,
window_size
)
print
(
"alibi_slopes: "
,
alibi_slopes
)
print
(
"page_table: "
,
page_table
.
shape
,
page_table
.
dtype
,
page_table
.
is_contiguous
(),
page_table
.
stride
())
print
(
"page_table: "
,
page_table
)
print
(
"softcap: "
,
softcap
)
print
(
"fa_version: "
,
fa_version
)
print
(
"q_descale: "
,
q_descale
.
shape
,
q_descale
.
dtype
,
q_descale
.
tolist
())
print
(
"k_descale: "
,
k_descale
.
shape
,
k_descale
.
dtype
,
k_descale
.
tolist
())
print
(
"v_descale: "
,
v_descale
.
shape
,
v_descale
.
dtype
,
v_descale
.
tolist
())
print
(
"--------------------------------------------------------------------------------------------"
)
# 先从 kvcache 中还原出 key 和 value
key_original
=
[]
value_original
=
[]
for
b
in
range
(
batch_size
):
# 获取页表索引
index
=
page_table
[
b
]
# 获取实际的索引
max_page_blocks
=
math
.
ceil
(
cache_seqlens
[
b
]
/
page_block_size
)
actual_index
=
index
[:
max_page_blocks
]
# 根据该页表索引获取当前 seqlenkv 的内容
key_content
=
k_cache_ref
[
actual_index
]
# reshape 回去
key_content
=
key_content
.
view
(
-
1
,
num_heads_kv
,
head_dim_qk
)[:
cache_seqlens
[
b
]].
contiguous
()
# 同理
value_content
=
v_cache_ref
[
actual_index
].
view
(
-
1
,
num_heads_kv
,
head_dim_v
)[:
cache_seqlens
[
b
]].
contiguous
()
key_original
.
append
(
key_content
)
value_original
.
append
(
value_content
)
# 同理还原出 query 的内容
query_original
=
[]
cum_q
=
0
for
b
in
range
(
batch_size
):
query_len
=
cu_seqlens_q
[
b
+
1
]
-
cu_seqlens_q
[
b
]
query_content
=
q_ref
[
cum_q
:
cum_q
+
query_len
]
query_original
.
append
(
query_content
.
contiguous
())
cum_q
+=
query_len
# 重新实现 self-attention
golden
=
[]
golden_lse
=
[]
golden_max
=
[]
for
b
in
range
(
batch_size
):
tmp_output
,
lse
,
scores_max
,
scores_sum
=
scaled_dot_product_attention
(
query_original
[
b
],
key_original
[
b
],
value_original
[
b
],
num_heads
,
num_heads_kv
,
is_causal
=
causal
,
USE_CPU
=
args
.
cpu
,
window_size
=
window_size
)
golden
.
append
(
tmp_output
)
golden_lse
.
append
(
lse
)
golden_max
.
append
(
scores_max
)
golden
=
torch
.
cat
(
golden
,
dim
=
0
)
golden_lse
=
torch
.
cat
(
golden_lse
,
dim
=-
1
)
golden_max
=
torch
.
cat
(
golden_max
,
dim
=-
1
)
print
(
"golden: "
,
golden
.
shape
)
print
(
"golden_lse: "
,
golden_lse
.
shape
)
print
(
"--------------------------------------------------------------------------------------------"
)
if
(
True
):
# fa_output, fa_lse = flash_attn_2_cuda.prefix_decode_varlen_fwd(
bshd_pa_decode
=
_require_hg_varlen_symbol
(
"hg_prefix_decode_varlen_fwd"
)
fa_output
,
fa_lse
=
bshd_pa_decode
(
q
,
k_cache
,
v_cache
,
None
,
# out_
cu_seqlens_q
,
None
,
# cu_seqlens_k
cache_seqlens
,
alibi_slopes
,
page_table
,
max_seqlen_q
,
max_seqlen_k
,
0.0
,
# dropout
softmax_scale
,
False
,
# zero_tensors
causal
,
window_size
[
0
],
window_size
[
1
],
softcap
,
True
,
# return_softmax_lse,
1
,
q_descale
if
args
.
fp8
else
None
,
k_descale
if
args
.
fp8
else
None
,
v_descale
if
args
.
fp8
else
None
,
None
,
# s_aux
infer_dtype
==
torch
.
bfloat16
,
)
else
:
fa_output
,
fa_lse
,
*
rest
=
flash_attn_with_kvcache
(
q
=
q
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
page_table
=
page_table
,
cache_seqlens
=
cache_seqlens
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k_new
=
cache_seqlens
,
max_seqlen_q
=
max_seqlen_q
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
)
torch
.
cuda
.
synchronize
()
if
(
vllm_golden
is
not
None
):
# 检查保存流程是否有错误
cal_diff
(
fa_output
,
vllm_golden
,
"check"
)
print
(
"fa_output: "
,
fa_output
.
shape
)
if
(
fa_lse
is
not
None
):
print
(
"fa_lse: "
,
fa_lse
.
shape
)
# 检验精度如何
fp8_threshold
=
5e-3
cal_diff
(
golden
,
fa_output
,
"accuracy"
,
True
,
fp8_threshold
if
args
.
fp8
else
1e-5
)
if
(
fa_lse
is
not
None
):
cal_diff
(
golden_lse
,
fa_lse
,
"softmax_lse"
,
True
,
fp8_threshold
if
args
.
fp8
else
1e-5
)
print
(
"--------------------------------------------------------------------------------------------"
)
# benchmark 性能数据
import
triton
def
benchmark_prefix_prefill
():
_
=
bshd_pa_decode
(
q
,
k_cache
,
v_cache
,
None
,
cu_seqlens_q
,
None
,
cache_seqlens
,
alibi_slopes
,
page_table
,
max_seqlen_q
,
max_seqlen_k
,
0.0
,
softmax_scale
,
False
,
causal
,
window_size
[
0
],
window_size
[
1
],
softcap
,
True
,
1
,
q_descale
if
args
.
fp8
else
None
,
k_descale
if
args
.
fp8
else
None
,
v_descale
if
args
.
fp8
else
None
,
None
,
infer_dtype
==
torch
.
bfloat16
,
)
# 适时关闭, 用于 debug
if
((
os
.
getenv
(
"FA_DEBUG"
)
is
None
)
and
(
os
.
getenv
(
"HIP_LOG_LEVEL"
)
is
None
)
and
not
args
.
trace
):
import
triton
t
=
triton
.
testing
.
do_bench_cudagraph
(
benchmark_prefix_prefill
)
FLOPS
=
float
(
0
)
BYTES
=
float
(
0
)
for
b
in
range
(
batch_size
):
batch_seqlen_q
=
cu_seqlens_q
[
b
+
1
]
-
cu_seqlens_q
[
b
]
batch_seqlen_k
=
cache_seqlens
[
b
]
effective_seqlen_k
=
batch_seqlen_k
if
window_size
!=
(
-
1
,
-
1
):
window_left
,
window_right
=
window_size
left
=
batch_seqlen_k
if
window_left
<
0
else
window_left
right
=
batch_seqlen_k
if
window_right
<
0
else
window_right
effective_seqlen_k
=
min
(
batch_seqlen_k
,
left
+
batch_seqlen_q
+
right
)
undo_flops
=
batch_seqlen_q
*
batch_seqlen_q
/
2
if
(
causal
and
window_size
==
(
-
1
,
-
1
))
else
0
attn_elems
=
batch_seqlen_q
*
effective_seqlen_k
-
undo_flops
qk_flops
=
num_heads
*
attn_elems
*
head_dim_qk
*
2
pv_flops
=
num_heads
*
attn_elems
*
head_dim_v
*
2
FLOPS
+=
qk_flops
+
pv_flops
q_load
=
batch_seqlen_q
*
num_heads
*
head_dim_qk
k_load
=
effective_seqlen_k
*
num_heads_kv
*
head_dim_qk
# k load not only once
v_load
=
effective_seqlen_k
*
num_heads_kv
*
head_dim_v
BYTES
+=
q_load
*
q
.
element_size
()
+
k_load
*
k_cache
.
element_size
()
+
v_load
*
v_cache
.
element_size
()
# ignore storation ?
print
(
f
"Performance:
{
t
:.
3
f
}
ms,
\x1b
[35m
{
FLOPS
/
10
**
9
/
t
:.
2
f
}
\x1b
[0m TFLOPS,
\x1b
[35m
{
BYTES
/
10
**
6
/
t
:.
0
f
}
\x1b
[0m GB/s"
)
# 压力测试
if
(
args
.
pressure
):
pressure_count
=
max
(
100
,
args
.
iterations
)
for
p
in
range
(
pressure_count
):
pressure_fa_output
=
torch
.
zeros_like
(
fa_output
)
pressure_fa_output
,
_
=
bshd_pa_decode
(
q
.
clone
(),
k_cache
.
clone
(),
v_cache
.
clone
(),
None
,
cu_seqlens_q
,
None
,
cache_seqlens
,
alibi_slopes
,
page_table
,
max_seqlen_q
,
max_seqlen_k
,
0.0
,
softmax_scale
,
False
,
causal
,
window_size
[
0
],
window_size
[
1
],
softcap
,
True
,
1
,
q_descale
if
args
.
fp8
else
None
,
k_descale
if
args
.
fp8
else
None
,
v_descale
if
args
.
fp8
else
None
,
infer_dtype
==
torch
.
bfloat16
,
)
torch
.
cuda
.
synchronize
()
is_equal
=
torch
.
equal
(
pressure_fa_output
,
fa_output
)
if
(
not
is_equal
):
cal_diff
(
pressure_fa_output
,
fa_output
,
"pressure"
)
assert
is_equal
,
"
\x1b
[31mUnstable
\x1b
[0m!"
del
pressure_fa_output
sys
.
stdout
.
write
(
"
\r
Pressure Test: {}/{}"
.
format
(
p
+
1
,
pressure_count
))
print
(
"
\x1b
[32mPASS
\x1b
[0m"
)
print
(
"-----------------------------------------------------------------------------------"
)
tests/test_unified_attn.py
View file @
518a5f4d
import
torch
import
os
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
...
...
@@ -6,6 +7,39 @@ from vllm.triton_utils import tl, triton
import
math
import
time
from
typing
import
Optional
UNIFIED_BLOCK_SIZE
=
int
(
os
.
getenv
(
"UNIFIED_BLOCK_SIZE"
,
"128"
))
def estimate_unified_attention_bytes(
    batch_size,
    seqlen_q,
    seqlen_k,
    nheads,
    nheads_k,
    d,
    block_size,
    q_bytes,
    k_bytes,
    v_bytes,
    out_bytes=0,
    d_v=None,
    window_size=(-1, -1),
):
    """Estimate how many bytes one unified-attention call moves.

    The estimate is the sum of four terms: query bytes, key/value bytes,
    output bytes and a small per-batch metadata term.

    Args:
        batch_size: Number of sequences.
        seqlen_q: Query length per sequence.
        seqlen_k: Key/value length per sequence.
        nheads: Number of query heads.
        nheads_k: Number of key/value heads.
        d: Head dimension used for Q and K.
        block_size: Divisor used to count KV blocks per sequence.
        q_bytes, k_bytes, v_bytes: Bytes per element of Q, K and V.
        out_bytes: Bytes per output element; the default 0 leaves the
            output term out of the total.
        d_v: Head dimension of V (and of the output); falls back to ``d``.
        window_size: ``(left, right)`` window. ``(-1, -1)`` counts the full
            ``seqlen_k``; a negative side is treated as ``seqlen_k`` wide.

    Returns:
        The estimated byte count.
    """
    if d_v is None:
        d_v = d

    # KV positions counted per sequence: everything when no window is set,
    # otherwise the window span plus the query length, capped at seqlen_k.
    if window_size == (-1, -1):
        kv_len = seqlen_k
    else:
        span_left, span_right = window_size
        if span_left < 0:
            span_left = seqlen_k
        if span_right < 0:
            span_right = seqlen_k
        kv_len = min(seqlen_k, span_left + seqlen_q + span_right)

    blocks_per_seq = math.ceil(kv_len / block_size)

    query_part = batch_size * seqlen_q * nheads * d * q_bytes
    kv_part = kv_len * batch_size * nheads_k * (d * k_bytes + d_v * v_bytes)
    output_part = batch_size * seqlen_q * nheads * d_v * out_bytes
    # Three 4-byte-per-entry arrays of (batch_size + 1), batch_size and
    # batch_size * blocks_per_seq entries -- presumably cu_seqlens_q,
    # seqused_k and the block table as int32; confirm against the callers.
    index_part = (
        (batch_size + 1) * 4 + batch_size * 4 + batch_size * blocks_per_seq * 4
    )
    return query_part + kv_part + output_part + index_part
import
pytest
import
torch
import
torch.nn.functional
as
F
...
...
@@ -18,19 +52,8 @@ import pdb
from
einops
import
rearrange
,
repeat
from
flash_attn
import
(
flash_attn_func
,
flash_attn_kvpacked_func
,
flash_attn_qkvpacked_func
,
flash_attn_varlen_func
,
flash_attn_varlen_kvpacked_func
,
flash_attn_varlen_qkvpacked_func
,
flash_attn_with_kvcache
,
varlen_fwd_unified
,
)
from
flash_attn
import
flash_attn_func
from
flash_attn.bert_padding
import
pad_input
,
unpad_input
from
flash_attn.flash_attn_interface
import
_get_block_size_n
from
flash_attn.layers.rotary
import
apply_rotary_emb
MAX_HEADDIM_SM8x
=
192
...
...
@@ -109,6 +132,7 @@ def kernel_unified_attention_2d(
TILE_SIZE
:
tl
.
constexpr
,
# int must be power of 2
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
VALUE_HEAD_SIZE
:
tl
.
constexpr
,
# int
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
USE_ALIBI_SQRT
:
tl
.
constexpr
,
# bool
USE_QQ_BIAS
:
tl
.
constexpr
,
# bool
...
...
@@ -167,6 +191,7 @@ def kernel_unified_attention_2d(
)
dim_mask
=
tl
.
where
(
offs_d
<
HEAD_SIZE
,
1
,
0
).
to
(
tl
.
int1
)
value_dim_mask
=
tl
.
where
(
offs_d
<
VALUE_HEAD_SIZE
,
1
,
0
).
to
(
tl
.
int1
)
query_mask_0
=
tl
.
where
(
query_pos
<
cur_batch_query_len
,
1
,
0
).
to
(
tl
.
int1
)
query_mask_1
=
tl
.
where
(
query_offset_1
<
num_query_heads
,
1
,
0
).
to
(
tl
.
int1
)
...
...
@@ -296,7 +321,7 @@ def kernel_unified_attention_2d(
# V : (TILE_SIZE, HEAD_SIZE)
V_load
=
tl
.
load
(
value_cache_ptr
+
v_offset
,
mask
=
dim_mask
[
None
,
:]
&
tile_mask
[:,
None
],
mask
=
value_
dim_mask
[
None
,
:]
&
tile_mask
[:,
None
],
other
=
0.0
,
)
...
...
@@ -425,7 +450,7 @@ def kernel_unified_attention_2d(
tl
.
store
(
output_ptr
+
output_offset
,
acc
,
mask
=
dim_mask
[
None
,
:]
&
query_mask_0
[:,
None
]
&
query_mask_1
[:,
None
],
mask
=
value_
dim_mask
[
None
,
:]
&
query_mask_0
[:,
None
]
&
query_mask_1
[:,
None
],
)
...
...
@@ -1039,6 +1064,7 @@ def unified_attention(
TILE_SIZE
=
TILE_SIZE_PREFILL
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
VALUE_HEAD_SIZE
=
v
.
shape
[
3
],
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_ALIBI_SQRT
=
use_alibi_sqrt
,
USE_QQ_BIAS
=
use_qq_bias
,
...
...
@@ -1235,7 +1261,14 @@ def make_paged_kv(k_list, v_list, block_size, device, dtype):
return
k_cache
,
v_cache
,
block_table
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
,
use_fp8
:
bool
=
False
,
is_e5m2
:
bool
=
False
)
->
None
:
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
,
use_fp8
:
bool
=
False
,
is_e5m2
:
bool
=
False
,
cos_threshold
:
Optional
[
float
]
=
None
,
)
->
None
:
torch_dtype
=
x
.
dtype
x
,
y
=
x
.
double
(),
y
.
double
()
RMSE
=
((
x
-
y
)
*
(
x
-
y
)).
mean
().
sqrt
().
item
()
...
...
@@ -1245,23 +1278,60 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
if
is_e5m2
:
assert
cos_diff
<
1e-2
elif
use_fp8
:
assert
cos_diff
<
1e-3
assert
cos_diff
<
5e-3
elif
cos_threshold
is
not
None
:
assert
cos_diff
<
cos_threshold
else
:
assert
cos_diff
<
(
1e-4
if
torch_dtype
==
torch
.
bfloat16
else
1e-5
)
class
_NoUnifiedFallback
:
def
__init__
(
self
,
module
):
self
.
_module
=
module
def
__getattr__
(
self
,
name
):
if
name
==
"varlen_fwd_unified"
:
raise
AssertionError
(
"unexpected fallback to flash_attn_cuda.varlen_fwd_unified"
)
return
getattr
(
self
.
_module
,
name
)
def varlen_fwd_unified_expect_hg(expected_symbol, *args, **kwargs):
    """Call ``varlen_fwd_unified`` and assert that it requested ``expected_symbol``.

    While the call runs, two entries in ``varlen_fwd_unified``'s module
    globals are swapped out and afterwards restored, even if the call raises:

    * ``flash_attn_cuda`` is wrapped in ``_NoUnifiedFallback``;
    * ``_require_hg_varlen_symbol`` is replaced by a wrapper that records
      each requested name before delegating to the original.

    Args:
        expected_symbol: Name that must appear among the recorded requests.
        *args, **kwargs: Forwarded unchanged to ``varlen_fwd_unified``.

    Returns:
        Whatever ``varlen_fwd_unified`` returns.

    Raises:
        AssertionError: if ``expected_symbol`` was never requested.
    """
    target_globals = varlen_fwd_unified.__globals__
    # Remember the originals so they can be put back afterwards.
    saved = {
        key: target_globals[key]
        for key in ("flash_attn_cuda", "_require_hg_varlen_symbol")
    }
    called_symbols = []

    def require_hg_symbol(name):
        # Record the request, then defer to the real implementation.
        called_symbols.append(name)
        return saved["_require_hg_varlen_symbol"](name)

    target_globals["flash_attn_cuda"] = _NoUnifiedFallback(saved["flash_attn_cuda"])
    target_globals["_require_hg_varlen_symbol"] = require_hg_symbol
    try:
        result = varlen_fwd_unified(*args, **kwargs)
    finally:
        target_globals.update(saved)

    assert expected_symbol in called_symbols, (
        f"expected {expected_symbol}, got HG calls {called_symbols}"
    )
    return result
# ---------------------------------------------------------------------------
# Accuracy tests
# ---------------------------------------------------------------------------
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"use_fp8"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"gqa"
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"softcap"
,
[
0.0
])
@
pytest
.
mark
.
parametrize
(
"window_size"
,
[(
-
1
,
-
1
)])
@
pytest
.
mark
.
parametrize
(
"window_size"
,
[(
-
1
,
-
1
)
,
(
511
,
0
)
])
@
pytest
.
mark
.
parametrize
(
"use_alibi_sqrt"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_qq_bias"
,
[
True
,
False
])
# seqlen_q > seqlen_k 时 skip
@
pytest
.
mark
.
parametrize
(
"use_sinks"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_mm_prefix"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"d
,d_v
"
,
[
(
128
,
128
),
(
192
,
128
),
(
256
,
256
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size,seqlen_q,seqlen_k,block_size"
,
[
...
...
@@ -1274,43 +1344,63 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False,
# --- 场景 2: Decode 场景 (增量推理) ---
# 验证 seqlen_q=1 时,如何正确从 KV Cache 的最后位置读取信息
# 此时 qq_bias 实际上只退化为向量加法,最容易测出指针偏移错误
(
8
,
1
,
2048
,
128
),
# 高 Batch 的标准 Decode
(
1
,
1
,
4096
,
128
),
# 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Chunked Prefill / Speculative Decoding (分段/投机采样) ---
# Q 小于 K,但大于 1。这是最难写的逻辑,验证 Is_causal 的动态截断
(
2
,
128
,
1024
,
128
),
# Q 是一小段,K 是长历史
(
4
,
256
,
512
,
128
),
# 验证 Q 和 K 比例较近时的处理
# --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
(
64
,
1
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
16
,
1
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
1
,
1
,
4096
,
UNIFIED_BLOCK_SIZE
),
(
64
,
4
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
32
,
4
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
16
,
4
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
8
,
4
,
2048
,
UNIFIED_BLOCK_SIZE
),
# 高 Batch 的标准 Decode
(
4
,
4
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
2
,
4
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
1
,
4
,
4096
,
UNIFIED_BLOCK_SIZE
),
# 超长上下文 Decode,验证大索引寻址
# --- 场景 3: Prefix Prefill ---
(
1
,
16
,
128
,
UNIFIED_BLOCK_SIZE
),
# fp8 prefill lower boundary
(
1
,
32
,
512
,
UNIFIED_BLOCK_SIZE
),
(
2
,
32
,
513
,
UNIFIED_BLOCK_SIZE
),
# non block-aligned KV length
(
2
,
64
,
2048
,
UNIFIED_BLOCK_SIZE
),
(
3
,
96
,
1537
,
UNIFIED_BLOCK_SIZE
),
# non power-of-two batch/length
(
2
,
128
,
1024
,
UNIFIED_BLOCK_SIZE
),
# # --- 场景 4: 边界非对称尺寸 (非 2 的幂次) ---
# # 专门用来抓那些“假设数据一定是 BlockSize 整数倍”的 Bug
(
1
,
127
,
127
,
128
),
# 刚好差 1 个填满 Block
(
2
,
33
,
1025
,
128
),
# 非常细碎的 Block 和不规则长度
],
)
def
test_unified_attn_2d
(
batch_size
,
seqlen_q
,
seqlen_k
,
block_size
,
d
,
causal
,
window_size
,
softcap
,
d
,
d_v
,
causal
,
window_size
,
softcap
,
mha_type
,
dtype
,
use_alibi_sqrt
,
use_qq_bias
,
use_sinks
,
use_mm_prefix
,
use_alibi_sqrt
,
use_qq_bias
,
use_sinks
,
use_mm_prefix
,
use_fp8
,
):
device
=
torch
.
device
(
"cuda"
)
torch
.
manual_seed
(
42
)
nheads
=
8
nheads_k
=
1
if
mha_type
==
"gqa"
else
nheads
nheads
=
16
nheads_k
=
2
if
mha_type
==
"gqa"
else
nheads
softmax_scale
=
d
**
(
-
0.5
)
MAX_MM_RANGES
=
2
# skip invalid combos
if
use_alibi_sqrt
and
not
causal
:
pytest
.
skip
(
"alibi_sqrt only tested with causal=True"
)
if
use_alibi_sqrt
:
pytest
.
skip
(
"HG unified attention does not support alibi_sqrt yet"
)
if
use_qq_bias
and
seqlen_q
>
seqlen_k
:
pytest
.
skip
(
"qq_bias requires seqlen_q <= seqlen_k"
)
if
use_qq_bias
:
pytest
.
skip
(
"HG unified attention does not support qq_bias yet"
)
if
use_mm_prefix
and
seqlen_q
>
seqlen_k
:
pytest
.
skip
(
"mm_prefix not supported when seqlen_q > seqlen_k"
)
if
use_mm_prefix
:
pytest
.
skip
(
"HG unified attention does not support mm_prefix yet"
)
if
(
not
use_fp8
)
and
use_sinks
and
(
seqlen_q
==
1
or
1
<
seqlen_q
<
16
):
pytest
.
skip
(
"b16 prefix decode sinks are not supported yet"
)
if
(
not
use_fp8
)
and
use_sinks
and
seqlen_q
>=
16
and
d
==
256
and
d_v
==
256
:
pytest
.
skip
(
"b16 prefix prefill 256/256 sinks are not supported yet"
)
# if use_mm_prefix and not causal:
# pytest.skip("mm_prefix_range is only meaningful with causal=True")
...
...
@@ -1318,15 +1408,30 @@ def test_unified_attn_2d(
for
_
in
range
(
batch_size
):
q_list
.
append
(
torch
.
randn
(
seqlen_q
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
))
k_list
.
append
(
torch
.
randn
(
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
))
v_list
.
append
(
torch
.
randn
(
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
))
v_list
.
append
(
torch
.
randn
(
seqlen_k
,
nheads_k
,
d_v
,
device
=
device
,
dtype
=
dtype
))
if
use_fp8
:
fp8_dtype
=
current_platform
.
fp8_dtype
()
q_kernel_list
=
[
q
.
to
(
fp8_dtype
)
for
q
in
q_list
]
k_kernel_list
=
[
k
.
to
(
fp8_dtype
)
for
k
in
k_list
]
v_kernel_list
=
[
v
.
to
(
fp8_dtype
)
for
v
in
v_list
]
q_ref_list
=
[
q
.
to
(
dtype
)
for
q
in
q_kernel_list
]
k_ref_list
=
[
k
.
to
(
dtype
)
for
k
in
k_kernel_list
]
v_ref_list
=
[
v
.
to
(
dtype
)
for
v
in
v_kernel_list
]
else
:
q_kernel_list
,
k_kernel_list
,
v_kernel_list
=
q_list
,
k_list
,
v_list
q_ref_list
,
k_ref_list
,
v_ref_list
=
q_list
,
k_list
,
v_list
q_varlen
=
torch
.
cat
(
q_list
,
dim
=
0
)
q_varlen
=
torch
.
cat
(
q_
kernel_
list
,
dim
=
0
)
cu_seqlens_q
=
torch
.
zeros
(
batch_size
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
torch
.
tensor
([
seqlen_q
]
*
batch_size
,
dtype
=
torch
.
int32
),
dim
=
0
)
seqused_k
=
torch
.
tensor
([
seqlen_k
]
*
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
k_cache
,
v_cache
,
block_table
=
make_paged_kv
(
k_list
,
v_list
,
block_size
,
device
,
dtype
)
k_cache
,
v_cache
,
block_table
=
make_paged_kv
(
k_kernel_list
,
v_kernel_list
,
block_size
,
device
,
q_varlen
.
dtype
)
q_descale
=
torch
.
ones
((
batch_size
,
nheads
),
device
=
device
,
dtype
=
torch
.
float32
)
if
use_fp8
else
None
k_descale
=
torch
.
ones
((
batch_size
,
nheads_k
),
device
=
device
,
dtype
=
torch
.
float32
)
if
use_fp8
else
None
v_descale
=
torch
.
ones
((
batch_size
,
nheads_k
),
device
=
device
,
dtype
=
torch
.
float32
)
if
use_fp8
else
None
# Build optional tensors
alibi_slopes
=
None
...
...
@@ -1339,7 +1444,8 @@ def test_unified_attn_2d(
sinks
=
None
if
use_sinks
:
sinks
=
torch
.
randn
(
nheads
,
device
=
device
,
dtype
=
dtype
)
sink_dtype
=
torch
.
bfloat16
if
use_fp8
else
torch
.
float32
sinks
=
(
torch
.
randn
(
nheads
,
device
=
device
,
dtype
=
sink_dtype
)
*
0.25
)
+
2.0
mm_prefix_range
=
None
if
use_mm_prefix
:
...
...
@@ -1354,7 +1460,7 @@ def test_unified_attn_2d(
for
i
in
range
(
batch_size
):
ref_outs
.
append
(
ref_attn
(
q_list
[
i
],
k_list
[
i
],
v_list
[
i
],
q_
ref_
list
[
i
],
k_
ref_
list
[
i
],
v_
ref_
list
[
i
],
causal
=
causal
,
window_size
=
window_size
,
softmax_scale
=
softmax_scale
,
...
...
@@ -1369,7 +1475,19 @@ def test_unified_attn_2d(
ref_out
=
torch
.
cat
(
ref_outs
,
dim
=
0
)
# ---- CUDA kernel ----
cuda_out
,
cuda_lse
=
varlen_fwd_unified
(
expected_hg_symbol
=
None
if
seqlen_q
>=
16
:
if
not
(
not
use_fp8
and
d
==
256
and
d_v
==
256
):
expected_hg_symbol
=
"hg_prefix_prefill_varlen_fwd"
elif
use_fp8
or
seqlen_q
==
1
or
1
<
seqlen_q
<
16
:
expected_hg_symbol
=
"hg_prefix_decode_varlen_fwd"
varlen_runner
=
(
varlen_fwd_unified
if
expected_hg_symbol
is
None
else
lambda
*
args
,
**
kwargs
:
varlen_fwd_unified_expect_hg
(
expected_hg_symbol
,
*
args
,
**
kwargs
)
)
cuda_out
,
cuda_lse
=
varlen_runner
(
q_varlen
,
k_cache
,
v_cache
,
cu_seqlens_q
,
seqused_k
,
block_table
,
max_seqlen_q
=
seqlen_q
,
...
...
@@ -1384,6 +1502,9 @@ def test_unified_attn_2d(
s_aux
=
sinks
,
mm_prefix_range
=
mm_prefix_range
,
return_softmax_lse
=
True
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
# # ---- Triton kernel ----
...
...
@@ -1416,15 +1537,164 @@ def test_unified_attn_2d(
# triton_max_diff = (triton_out - ref_out).abs().max().item()
print
(
f
"
\n
[
{
dtype
}
| causal=
{
causal
}
|
{
mha_type
}
| bs=
{
batch_size
}
"
f
"sq=
{
seqlen_q
}
sk=
{
seqlen_k
}
blk=
{
block_size
}
| "
f
"
\n
[
{
dtype
}
|
fp8=
{
use_fp8
}
|
causal=
{
causal
}
|
{
mha_type
}
| bs=
{
batch_size
}
"
f
"sq=
{
seqlen_q
}
sk=
{
seqlen_k
}
d=
{
d
}
/
{
d_v
}
blk=
{
block_size
}
| "
f
"alibi_sqrt=
{
use_alibi_sqrt
}
qq_bias=
{
use_qq_bias
}
"
f
"sinks=
{
use_sinks
}
mm_prefix=
{
use_mm_prefix
}
]"
f
"
\n
CUDA max_diff=
{
cuda_max_diff
:.
4
e
}
"
# f"\n Triton max_diff={triton_max_diff:.4e}"
)
cal_diff
(
cuda_out
,
ref_out
,
"out"
)
cutlass_b16_256_prefill
=
(
not
use_fp8
and
seqlen_q
>=
16
and
d
==
256
and
d_v
==
256
)
cal_diff
(
cuda_out
,
ref_out
,
"out"
,
use_fp8
=
use_fp8
,
cos_threshold
=
1e-3
if
cutlass_b16_256_prefill
else
None
,
)
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 256, 4096, True, (-1, -1), id="causal-large-sq"),
        pytest.param(2, 257, 4103, True, (-1, -1), id="causal-unaligned-sq-sk"),
        pytest.param(4, 512, 8193, True, (-1, -1), id="causal-large-bs-sq-unaligned-sk"),
        # The torch reference materializes dense [heads, seq_q, seq_kv] scores,
        # so keep 4K/8K correctness at bs=1 while still hitting long prefill.
        pytest.param(1, 4096, 4096, True, (-1, -1), id="causal-sq4096"),
        pytest.param(1, 8192, 8192, True, (-1, -1), id="causal-sq8192"),
        pytest.param(1, 4097, 8193, True, (-1, -1), id="causal-unaligned-sq4097-sk8193"),
        pytest.param(1, 17, 129, False, (511, 0), id="swa-lower-boundary-unaligned"),
        pytest.param(3, 65, 1025, False, (511, 0), id="swa-mid-unaligned"),
        pytest.param(2, 513, 1537, False, (511, 0), id="swa-large-unaligned-sq-sk"),
        pytest.param(1, 4096, 8193, False, (511, 0), id="swa-sq4096-sk8193"),
        pytest.param(1, 8192, 8193, False, (511, 0), id="swa-sq8192-sk8193"),
    ],
)
def test_unified_attn_fp8_192x128_prefill_corner_cases(
    batch_size, seqlen_q, seqlen_k, causal, window_size
):
    """Run ``test_unified_attn_2d`` with FP8 inputs and head dims 192/128.

    Only the shape, causal flag and window vary per case; everything else
    is pinned below.
    """
    # Settings shared by every parametrized case of this test.
    fixed = dict(
        block_size=UNIFIED_BLOCK_SIZE,
        d=192,
        d_v=128,
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=False,
        use_mm_prefix=False,
        use_fp8=True,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        causal=causal,
        window_size=window_size,
        **fixed,
    )
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 17, 129, True, (-1, -1), id="causal-lower-boundary-unaligned"),
        pytest.param(2, 65, 1025, True, (-1, -1), id="causal-mid-unaligned"),
        pytest.param(1, 256, 4097, True, (-1, -1), id="causal-large-unaligned-sk"),
        pytest.param(1, 17, 129, False, (511, 0), id="swa-lower-boundary-unaligned"),
        pytest.param(2, 65, 1025, False, (511, 0), id="swa-mid-unaligned"),
        pytest.param(1, 256, 4097, False, (511, 0), id="swa-large-unaligned-sk"),
    ],
)
def test_unified_attn_fp8_256x256_prefill_corner_cases(
    batch_size, seqlen_q, seqlen_k, causal, window_size
):
    """Run ``test_unified_attn_2d`` with FP8 inputs and head dims 256/256.

    Only the shape, causal flag and window vary per case; everything else
    is pinned below.
    """
    # Settings shared by every parametrized case of this test.
    fixed = dict(
        block_size=UNIFIED_BLOCK_SIZE,
        d=256,
        d_v=256,
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=False,
        use_mm_prefix=False,
        use_fp8=True,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        causal=causal,
        window_size=window_size,
        **fixed,
    )
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 1, 4096, True, (-1, -1), id="decode-sq1-large-sk"),
        pytest.param(2, 4, 2048, True, (-1, -1), id="decode-mtp4-large-sk"),
        pytest.param(1, 17, 129, True, (-1, -1), id="prefill-causal-lower-boundary"),
        pytest.param(2, 65, 1025, False, (511, 0), id="prefill-swa-unaligned"),
    ],
)
def test_unified_attn_fp8_192x128_sinks_corner_cases(
    batch_size, seqlen_q, seqlen_k, causal, window_size
):
    """Run ``test_unified_attn_2d`` with FP8 inputs, head dims 192/128 and sinks on.

    Only the shape, causal flag and window vary per case; everything else
    is pinned below.
    """
    # Settings shared by every parametrized case of this test.
    fixed = dict(
        block_size=UNIFIED_BLOCK_SIZE,
        d=192,
        d_v=128,
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=True,
        use_mm_prefix=False,
        use_fp8=True,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        causal=causal,
        window_size=window_size,
        **fixed,
    )
@pytest.mark.parametrize("block_size", [64, 128])
@pytest.mark.parametrize("d,d_v", [(128, 128), (192, 128)])
@pytest.mark.parametrize(
    "batch_size,seqlen_q,seqlen_k,causal,window_size",
    [
        pytest.param(1, 64, 257, True, (-1, -1), id="prefill-causal-unaligned"),
        pytest.param(2, 65, 1025, False, (511, 0), id="prefill-swa-unaligned"),
        pytest.param(16, 1, 2048, True, (-1, -1), id="decode-sq1"),
        pytest.param(16, 4, 2048, False, (-1, -1), id="decode-mtp4-noncausal"),
        pytest.param(16, 4, 2048, False, (511, 0), id="decode-mtp4-swa"),
    ],
)
def test_unified_attn_b16_page64_regression(
    batch_size, seqlen_q, seqlen_k, causal, window_size, d, d_v, block_size
):
    """Run ``test_unified_attn_2d`` in bfloat16 (no FP8) over block sizes 64 and 128.

    Shape, causal flag, window, head dims and block size come from the
    parametrization; the remaining options are pinned below.
    """
    # Options that stay constant for every parametrized case.
    fixed = dict(
        softcap=0.0,
        mha_type="gqa",
        dtype=torch.bfloat16,
        use_alibi_sqrt=False,
        use_qq_bias=False,
        use_sinks=False,
        use_mm_prefix=False,
        use_fp8=False,
    )
    test_unified_attn_2d(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        block_size=block_size,
        d=d,
        d_v=d_v,
        causal=causal,
        window_size=window_size,
        **fixed,
    )
# ---------------------------------------------------------------------------
...
...
@@ -1440,8 +1710,8 @@ def benchmark_unified_attention():
torch
.
manual_seed
(
42
)
dtype
=
torch
.
float16
d
=
256
block_size
=
320
d
=
128
block_size
=
UNIFIED_BLOCK_SIZE
warmup
=
10
repeat
=
50
...
...
@@ -1452,30 +1722,35 @@ def benchmark_unified_attention():
MAX_MM_RANGES
=
2
# GQA
nheads
=
8
nheads_k
=
1
nheads
=
24
nheads_k
=
2
# workload shapes
shapes
=
[
# (8, 2048, 2048),
# (4, 2048, 2048),
(
4
,
1
,
2048
),
(
1
,
4
,
51200
),
(
2
,
4
,
51200
),
(
4
,
4
,
51200
),
(
8
,
4
,
51200
),
(
16
,
4
,
51200
),
(
32
,
4
,
51200
),
# (4, 2048, 4096),
]
# feature configs (C A Q S P)
feature_configs
=
[
(
1
,
0
,
0
,
0
,
0
),
(
1
,
1
,
0
,
0
,
0
),
(
1
,
0
,
1
,
0
,
0
),
(
1
,
0
,
0
,
1
,
0
),
(
1
,
0
,
0
,
0
,
1
),
(
1
,
1
,
1
,
0
,
0
),
(
1
,
1
,
0
,
1
,
0
),
(
1
,
1
,
0
,
0
,
1
),
(
1
,
1
,
1
,
1
,
0
),
(
1
,
1
,
1
,
0
,
1
),
(
1
,
1
,
1
,
1
,
1
),
#
(1,1,0,0,0),
#
(1,0,1,0,0),
#
(1,0,0,1,0),
#
(1,0,0,0,1),
#
(1,1,1,0,0),
#
(1,1,0,1,0),
#
(1,1,0,0,1),
#
(1,1,1,1,0),
#
(1,1,1,0,1),
#
(1,1,1,1,1),
]
print
(
"
\n
Unified Attention GQA Benchmark"
)
...
...
@@ -1485,7 +1760,8 @@ def benchmark_unified_attention():
f
"
{
'BS'
:
>
3
}
{
'SQ'
:
>
6
}
{
'SK'
:
>
6
}
| "
f
"
{
'C'
:
>
1
}
{
'A'
:
>
1
}
{
'Q'
:
>
1
}
{
'S'
:
>
1
}
{
'P'
:
>
1
}
| "
f
"
{
'CUDA(ms)'
:
>
10
}
{
'Triton(ms)'
:
>
11
}
| "
f
"
{
'CUDA TFLOPS'
:
>
11
}
{
'Triton TFLOPS'
:
>
13
}
|
{
'Speedup'
:
>
8
}
"
f
"
{
'CUDA TFLOPS'
:
>
11
}
{
'Triton TFLOPS'
:
>
13
}
| "
f
"
{
'CUDA(GB/s)'
:
>
10
}
{
'Triton(GB/s)'
:
>
12
}
|
{
'Speedup'
:
>
8
}
"
)
print
(
"-"
*
120
)
...
...
@@ -1548,6 +1824,15 @@ def benchmark_unified_attention():
triton_out
=
torch
.
zeros_like
(
q_varlen
)
total_bytes
=
estimate_unified_attention_bytes
(
batch_size
,
seqlen_q
,
seqlen_k
,
nheads
,
nheads_k
,
d
,
block_size
,
q_bytes
=
torch
.
finfo
(
dtype
).
bits
//
8
,
k_bytes
=
torch
.
finfo
(
dtype
).
bits
//
8
,
v_bytes
=
torch
.
finfo
(
dtype
).
bits
//
8
,
d_v
=
d
,
window_size
=
window_size
,
)
for
C
,
A
,
Q
,
S
,
P
in
feature_configs
:
causal
=
bool
(
C
)
...
...
@@ -1574,7 +1859,8 @@ def benchmark_unified_attention():
sinks
=
None
if
use_sinks
:
sinks
=
torch
.
randn
(
nheads
,
device
=
device
,
dtype
=
dtype
)
sink_dtype
=
torch
.
bfloat16
if
use_fp8
else
torch
.
float32
sinks
=
torch
.
randn
(
nheads
,
device
=
device
,
dtype
=
sink_dtype
)
mm_prefix_range
=
None
if
use_mm_prefix
:
...
...
@@ -1655,6 +1941,9 @@ def benchmark_unified_attention():
# FLOPs
flops
=
4.0
*
batch_size
*
nheads
*
seqlen_q
*
seqlen_k
*
d
cuda_bandwidth
=
total_bytes
/
1e9
/
cuda_ms
*
1000
triton_bandwidth
=
total_bytes
/
1e9
/
triton_ms
*
1000
cuda_tflops
=
flops
/
cuda_ms
/
1e9
triton_tflops
=
flops
/
triton_ms
/
1e9
...
...
@@ -1663,11 +1952,277 @@ def benchmark_unified_attention():
f
"
{
C
}
{
A
}
{
Q
}
{
S
}
{
P
}
| "
f
"
{
cuda_ms
:
10.3
f
}
{
triton_ms
:
11.3
f
}
| "
f
"
{
cuda_tflops
:
11.2
f
}
{
triton_tflops
:
13.2
f
}
| "
f
"
{
cuda_bandwidth
:
10.2
f
}
{
triton_bandwidth
:
12.2
f
}
| "
f
"
{
triton_ms
/
cuda_ms
:
8.2
f
}
x"
)
print
(
"="
*
120
)
def
benchmark_hg_b16_fp8_pa
():
device
=
torch
.
device
(
"cuda"
)
torch
.
manual_seed
(
42
)
dtype
=
torch
.
bfloat16
fp8_dtype
=
current_platform
.
fp8_dtype
()
d
=
192
d_v
=
128
block_size
=
UNIFIED_BLOCK_SIZE
warmup
=
10
repeat
=
50
nheads
=
16
nheads_k
=
2
softcap
=
0.0
shapes
=
[
(
bs
,
seqlen_q
,
51200
)
for
seqlen_q
in
(
1
,
4
)
for
bs
in
(
1
,
2
,
4
,
8
,
16
,
32
,
64
)
]
shapes
+=
[
(
bs
,
seqlen_q
,
4096
)
for
seqlen_q
in
(
32
,
128
,
256
,
512
)
for
bs
in
(
1
,
2
,
4
)
]
shapes
+=
[
(
1
,
257
,
4103
),
(
2
,
513
,
8193
),
(
1
,
4096
,
4096
),
(
1
,
4097
,
8193
),
(
1
,
8192
,
8192
),
(
1
,
8192
,
8193
),
]
windows
=
[(
-
1
,
-
1
),
(
511
,
0
)]
def
time_fn
(
fn
):
for
_
in
range
(
warmup
):
fn
()
torch
.
cuda
.
synchronize
()
start
=
time
.
perf_counter
()
for
_
in
range
(
repeat
):
fn
()
torch
.
cuda
.
synchronize
()
return
(
time
.
perf_counter
()
-
start
)
/
repeat
*
1000
def
unwrap_out
(
result
):
if
isinstance
(
result
,
(
tuple
,
list
)):
return
result
[
0
]
return
result
def
diff_stats
(
x
,
y
):
x
=
x
.
float
()
y
=
y
.
float
()
denom
=
torch
.
clamp
((
x
*
x
+
y
*
y
).
sum
(),
min
=
1e-12
)
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
()
/
denom
max_diff
=
(
x
-
y
).
abs
().
max
()
return
cos_diff
.
item
(),
max_diff
.
item
()
max_ref_scores
=
int
(
os
.
getenv
(
"BENCH_REF_MAX_SCORES"
,
"20000000"
))
print
(
"
\n
HG Unified PA BF16/FP8 Benchmark"
)
print
(
"="
*
170
)
print
(
f
"
{
'BS'
:
>
3
}
{
'SQ'
:
>
3
}
{
'SK'
:
>
6
}
{
'D'
:
>
7
}
{
'WINDOW'
:
>
12
}
| "
f
"
{
'HG_OK'
:
>
5
}
{
'HG(ms)'
:
>
8
}
{
'TRI(ms)'
:
>
8
}
{
'FP8(ms)'
:
>
8
}
| "
f
"
{
'FP8/HG'
:
>
8
}
{
'FP8/TRI'
:
>
8
}
{
'HG/TRI'
:
>
8
}
| "
f
"
{
'REF_cos'
:
>
9
}
{
'REF_max'
:
>
9
}
|
{
'FP8 GB/s'
:
>
9
}
{
'NOTE'
:
>
18
}
"
)
print
(
"-"
*
170
)
summary
=
{
"total"
:
0
,
"hg_ok"
:
0
,
"hg_fail"
:
0
,
"fp8_hg_speedups"
:
[],
"fp8_triton_speedups"
:
[],
}
for
window_size
in
windows
:
for
batch_size
,
seqlen_q
,
seqlen_k
in
shapes
:
summary
[
"total"
]
+=
1
causal
=
window_size
==
(
-
1
,
-
1
)
softmax_scale
=
d
**
(
-
0.5
)
q_list
=
[
torch
.
randn
(
seqlen_q
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
)
for
_
in
range
(
batch_size
)]
k_list
=
[
torch
.
randn
(
seqlen_k
,
nheads_k
,
d
,
device
=
device
,
dtype
=
dtype
)
for
_
in
range
(
batch_size
)]
v_list
=
[
torch
.
randn
(
seqlen_k
,
nheads_k
,
d_v
,
device
=
device
,
dtype
=
dtype
)
for
_
in
range
(
batch_size
)]
q_b16
=
torch
.
cat
(
q_list
,
dim
=
0
)
k_b16
,
v_b16
,
block_table
=
make_paged_kv
(
k_list
,
v_list
,
block_size
,
device
,
dtype
)
q_fp8
=
q_b16
.
to
(
fp8_dtype
)
k_fp8
=
k_b16
.
to
(
fp8_dtype
)
v_fp8
=
v_b16
.
to
(
fp8_dtype
)
cu_seqlens_q
=
torch
.
zeros
(
batch_size
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
torch
.
tensor
([
seqlen_q
]
*
batch_size
,
dtype
=
torch
.
int32
),
dim
=
0
)
seqused_k
=
torch
.
tensor
([
seqlen_k
]
*
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
q_descale
=
torch
.
ones
((
batch_size
,
nheads
),
device
=
device
,
dtype
=
torch
.
float32
)
k_descale
=
torch
.
ones
((
batch_size
,
nheads_k
),
device
=
device
,
dtype
=
torch
.
float32
)
v_descale
=
torch
.
ones
((
batch_size
,
nheads_k
),
device
=
device
,
dtype
=
torch
.
float32
)
expected_hg_symbol
=
(
"hg_prefix_prefill_varlen_fwd"
if
seqlen_q
>=
16
else
"hg_prefix_decode_varlen_fwd"
)
def
run_b16_hg_checked
():
return
unwrap_out
(
varlen_fwd_unified_expect_hg
(
expected_hg_symbol
,
q_b16
,
k_b16
,
v_b16
,
cu_seqlens_q
,
seqused_k
,
block_table
,
max_seqlen_q
=
seqlen_q
,
max_seqlen_k
=
seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
return_softmax_lse
=
False
,
))
def
run_b16_hg
():
return
unwrap_out
(
varlen_fwd_unified
(
q_b16
,
k_b16
,
v_b16
,
cu_seqlens_q
,
seqused_k
,
block_table
,
max_seqlen_q
=
seqlen_q
,
max_seqlen_k
=
seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
return_softmax_lse
=
False
,
))
def
run_fp8_checked
():
return
unwrap_out
(
varlen_fwd_unified_expect_hg
(
expected_hg_symbol
,
q_fp8
,
k_fp8
,
v_fp8
,
cu_seqlens_q
,
seqused_k
,
block_table
,
max_seqlen_q
=
seqlen_q
,
max_seqlen_k
=
seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
return_softmax_lse
=
False
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
))
def
run_fp8
():
return
unwrap_out
(
varlen_fwd_unified
(
q_fp8
,
k_fp8
,
v_fp8
,
cu_seqlens_q
,
seqused_k
,
block_table
,
max_seqlen_q
=
seqlen_q
,
max_seqlen_k
=
seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
return_softmax_lse
=
False
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
))
triton_out
=
torch
.
empty
(
(
q_b16
.
shape
[
0
],
nheads
,
d_v
),
device
=
device
,
dtype
=
dtype
)
def
run_triton_b16
():
unified_attention
(
q
=
q_b16
,
k
=
k_b16
,
v
=
v_b16
,
out
=
triton_out
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
window_size
,
block_table
=
block_table
,
softcap
=
softcap
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
seq_threshold_3D
=
128
,
)
return
triton_out
note
=
""
hg_ms
=
float
(
"nan"
)
fp8_over_hg
=
float
(
"nan"
)
hg_over_triton
=
float
(
"nan"
)
hg_cos
=
float
(
"nan"
)
hg_max
=
float
(
"nan"
)
triton_ms
=
time_fn
(
run_triton_b16
)
run_fp8_checked
()
torch
.
cuda
.
synchronize
()
fp8_ms
=
time_fn
(
run_fp8
)
try
:
b16_hg_out
=
run_b16_hg_checked
()
torch
.
cuda
.
synchronize
()
num_ref_scores
=
batch_size
*
seqlen_q
*
seqlen_k
*
nheads
hg_correct
=
True
if
num_ref_scores
<=
max_ref_scores
:
ref_out
=
torch
.
cat
([
ref_attn
(
q_list
[
i
],
k_list
[
i
],
v_list
[
i
],
causal
=
causal
,
window_size
=
window_size
,
softmax_scale
=
softmax_scale
,
softcap
=
softcap
,
)
for
i
in
range
(
batch_size
)
],
dim
=
0
)
hg_cos
,
hg_max
=
diff_stats
(
b16_hg_out
,
ref_out
)
hg_correct
=
hg_cos
<
1e-3
if
not
hg_correct
:
note
=
"HG_BF16_REF_DIFF"
else
:
note
=
"REF_SKIP"
if
not
hg_correct
:
summary
[
"hg_fail"
]
+=
1
else
:
hg_ms
=
time_fn
(
run_b16_hg
)
fp8_over_hg
=
hg_ms
/
fp8_ms
hg_over_triton
=
triton_ms
/
hg_ms
summary
[
"hg_ok"
]
+=
1
summary
[
"fp8_hg_speedups"
].
append
(
fp8_over_hg
)
except
Exception
as
exc
:
summary
[
"hg_fail"
]
+=
1
note
=
type
(
exc
).
__name__
fp8_over_triton
=
triton_ms
/
fp8_ms
summary
[
"fp8_triton_speedups"
].
append
(
fp8_over_triton
)
fp8_bytes
=
estimate_unified_attention_bytes
(
batch_size
,
seqlen_q
,
seqlen_k
,
nheads
,
nheads_k
,
d
,
block_size
,
q_bytes
=
1
,
k_bytes
=
1
,
v_bytes
=
1
,
out_bytes
=
2
,
d_v
=
d_v
,
window_size
=
window_size
,
)
print
(
f
"
{
batch_size
:
3
d
}
{
seqlen_q
:
3
d
}
{
seqlen_k
:
6
d
}
{
str
(
d
)
+
'/'
+
str
(
d_v
):
>
7
}
{
str
(
window_size
):
>
12
}
| "
f
"
{
str
(
math
.
isfinite
(
hg_ms
)):
>
5
}
{
hg_ms
:
8.3
f
}
{
triton_ms
:
8.3
f
}
{
fp8_ms
:
8.3
f
}
| "
f
"
{
fp8_over_hg
:
8.2
f
}
{
fp8_over_triton
:
8.2
f
}
{
hg_over_triton
:
8.2
f
}
| "
f
"
{
hg_cos
:
9.2
e
}
{
hg_max
:
9.2
e
}
| "
f
"
{
fp8_bytes
/
1e9
/
fp8_ms
*
1000
:
9.2
f
}
{
note
:
>
18
}
"
)
print
(
"-"
*
170
)
if
summary
[
"fp8_hg_speedups"
]:
hg_speedups
=
torch
.
tensor
(
summary
[
"fp8_hg_speedups"
],
dtype
=
torch
.
float32
)
print
(
"HG_BF16 baseline summary: "
f
"ok=
{
summary
[
'hg_ok'
]
}
/
{
summary
[
'total'
]
}
"
f
"fail=
{
summary
[
'hg_fail'
]
}
"
f
"fp8/hg mean=
{
hg_speedups
.
mean
().
item
():.
3
f
}
"
f
"median=
{
hg_speedups
.
median
().
item
():.
3
f
}
"
f
"min=
{
hg_speedups
.
min
().
item
():.
3
f
}
"
f
"max=
{
hg_speedups
.
max
().
item
():.
3
f
}
"
)
triton_speedups
=
torch
.
tensor
(
summary
[
"fp8_triton_speedups"
],
dtype
=
torch
.
float32
)
print
(
"Triton BF16 reference summary: "
f
"fp8/triton mean=
{
triton_speedups
.
mean
().
item
():.
3
f
}
"
f
"median=
{
triton_speedups
.
median
().
item
():.
3
f
}
"
f
"min=
{
triton_speedups
.
min
().
item
():.
3
f
}
"
f
"max=
{
triton_speedups
.
max
().
item
():.
3
f
}
"
)
print
(
"="
*
170
)
if
__name__
==
"__main__"
:
benchmark_unified_attention
()
\ No newline at end of file
if
os
.
getenv
(
"RUN_UNIFIED_BENCHMARK"
)
==
"1"
:
benchmark_hg_b16_fp8_pa
()
else
:
test_unified_attn_2d
(
batch_size
=
1
,
seqlen_q
=
1
,
seqlen_k
=
40960
,
block_size
=
UNIFIED_BLOCK_SIZE
,
d
=
192
,
d_v
=
128
,
causal
=
False
,
window_size
=
(
511
,
0
),
softcap
=
0.0
,
mha_type
=
"gqa"
,
dtype
=
torch
.
bfloat16
,
use_alibi_sqrt
=
False
,
use_qq_bias
=
False
,
use_sinks
=
False
,
use_mm_prefix
=
False
,
use_fp8
=
True
)
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment