Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
aiter
Commits
f233de81
Commit
f233de81
authored
Apr 17, 2026
by
Xiaowei.zhang
Browse files
[SYNC] Code sync.
parent
1893a1e0
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
314 additions
and
2 deletions
+314
-2
op_tests/triton_tests/test_extend_attention.py
op_tests/triton_tests/test_extend_attention.py
+307
-1
rebuild_aiter.sh
rebuild_aiter.sh
+2
-0
setup.py
setup.py
+5
-1
No files found.
op_tests/triton_tests/test_extend_attention.py
View file @
f233de81
# SPDX-License-Identifier: MIT
# SPDX-License-Identifier: MIT
import
torch
import
torch
import
torch.nn.functional
as
F
import
pytest
import
pytest
from
aiter.ops.triton.extend_attention
import
extend_attention_fwd
from
aiter.ops.triton.extend_attention
import
extend_attention_fwd
def
extend_attention_fwd_torch_swa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
kv_indices
:
torch
.
Tensor
,
sliding_window_size
:
int
,
*
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
sm_scale
:
float
|
None
=
None
,
):
"""Reference for causal + sliding-window extend attention (sglang test style).
Runs the heavy matmul/softmax on CPU float32 for numerical stability and to avoid
ROCm aborts on large bf16 einsum after GPU kernels.
v2 与 Triton 一致:``k_scale`` / ``v_scale`` **只作用在 prefix(cache)键位**;
extend 段 logits 与 V 不额外乘这两个标量。
"""
B
=
qo_indptr
.
size
(
0
)
-
1
_
,
H_Q
,
D
=
q
.
shape
_
,
H_KV
,
_
=
k
.
shape
group_size
=
H_Q
//
H_KV
scale
=
float
(
sm_scale
)
if
sm_scale
is
not
None
else
1.0
/
D
**
0.5
out_dev
=
o
.
device
out_dtype
=
o
.
dtype
for
i
in
range
(
B
):
q_start
=
int
(
qo_indptr
[
i
].
item
())
q_end
=
int
(
qo_indptr
[
i
+
1
].
item
())
kv_start
=
int
(
kv_indptr
[
i
].
item
())
kv_end
=
int
(
kv_indptr
[
i
+
1
].
item
())
prefix_indices
=
kv_indices
[
kv_start
:
kv_end
]
k_prefix
=
k_cache
[
prefix_indices
]
v_prefix
=
v_cache
[
prefix_indices
]
k_extend
=
k
[
q_start
:
q_end
]
v_extend
=
v
[
q_start
:
q_end
]
q_extend
=
q
[
q_start
:
q_end
]
k_full
=
torch
.
cat
([
k_prefix
,
k_extend
],
dim
=
0
)
v_full
=
torch
.
cat
([
v_prefix
,
v_extend
],
dim
=
0
)
if
group_size
!=
1
:
k_full_hq
=
k_full
.
repeat_interleave
(
group_size
,
dim
=
1
)
v_full_hq
=
v_full
.
repeat_interleave
(
group_size
,
dim
=
1
)
else
:
k_full_hq
=
k_full
v_full_hq
=
v_full
prefix_len
=
k_prefix
.
size
(
0
)
extend_len
=
k_extend
.
size
(
0
)
total_len
=
prefix_len
+
extend_len
q_e
=
q_extend
.
detach
().
float
().
cpu
()
k_h
=
k_full_hq
.
detach
().
float
().
cpu
()
v_h
=
v_full_hq
.
detach
().
float
().
cpu
()
pos_keys
=
torch
.
arange
(
total_len
)
t
=
prefix_len
+
torch
.
arange
(
extend_len
)
causal_mask
=
pos_keys
.
unsqueeze
(
0
)
<=
t
.
unsqueeze
(
1
)
if
sliding_window_size
is
not
None
and
sliding_window_size
>
0
:
start
=
(
t
-
sliding_window_size
).
clamp_min
(
0
)
else
:
start
=
torch
.
zeros_like
(
t
)
window_mask
=
pos_keys
.
unsqueeze
(
0
)
>=
start
.
unsqueeze
(
1
)
final_mask
=
causal_mask
&
window_mask
attn_scores
=
torch
.
einsum
(
"qhd,khd->qhk"
,
q_e
,
k_h
)
*
scale
if
k_scale
!=
1.0
:
attn_scores
[:,
:,
:
prefix_len
]
=
attn_scores
[:,
:,
:
prefix_len
]
*
k_scale
attn_scores
=
attn_scores
.
masked_fill
(
~
final_mask
.
unsqueeze
(
1
),
float
(
"-inf"
))
attn_weights
=
F
.
softmax
(
attn_scores
,
dim
=-
1
)
if
v_scale
!=
1.0
:
v_prefix
=
v_h
[:
prefix_len
]
*
v_scale
v_h_scaled
=
torch
.
cat
([
v_prefix
,
v_h
[
prefix_len
:]],
dim
=
0
)
else
:
v_h_scaled
=
v_h
out_cpu
=
torch
.
einsum
(
"qhk,khd->qhd"
,
attn_weights
,
v_h_scaled
)
o
[
q_start
:
q_end
]
=
out_cpu
.
to
(
device
=
out_dev
,
dtype
=
out_dtype
)
def
input_helper
(
def
input_helper
(
B
,
B
,
H
,
H
,
...
@@ -269,6 +363,218 @@ def test_op_fwd(
...
@@ -269,6 +363,218 @@ def test_op_fwd(
torch
.
testing
.
assert_close
(
ref_out
,
tri_out
,
rtol
=
2e-2
,
atol
=
2e-2
)
torch
.
testing
.
assert_close
(
ref_out
,
tri_out
,
rtol
=
2e-2
,
atol
=
2e-2
)
def
test_extend_attention_v2_identity_scales_match_v1
():
"""v2 with fp32 1.0 scales should match v1 (k_scale/v_scale None)."""
device
=
"cuda"
dtype
=
torch
.
float16
torch
.
manual_seed
(
0
)
(
q_extend
,
k_extend
,
v_extend
,
k_buffer
,
v_buffer
,
kv_indptr
,
kv_indices
,
qo_indptr
,
custom_mask
,
mask_indptr
,
max_len_extend
,
)
=
input_helper
(
2
,
4
,
64
,
32
,
128
,
64
,
128
,
dtype
,
device
,
"normal"
)
out_v1
=
torch
.
empty
(
(
*
q_extend
.
shape
[:
-
1
],
v_extend
.
shape
[
-
1
]),
dtype
=
q_extend
.
dtype
,
device
=
device
,
)
out_v2
=
torch
.
empty_like
(
out_v1
)
extend_attention_fwd
(
q_extend
,
k_extend
,
v_extend
,
out_v1
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
True
,
mask_indptr
,
max_len_extend
,
sm_scale
=
None
,
logit_cap
=
0.0
,
skip_prefix_custom_mask
=
True
,
config
=
None
,
)
extend_attention_fwd
(
q_extend
,
k_extend
,
v_extend
,
out_v2
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
True
,
mask_indptr
,
max_len_extend
,
sm_scale
=
None
,
logit_cap
=
0.0
,
skip_prefix_custom_mask
=
True
,
config
=
None
,
k_scale
=
1.0
,
v_scale
=
1.0
,
sliding_window_size
=-
1
,
sinks
=
None
,
window_kv_offsets
=
None
,
xai_temperature_len
=-
1
,
)
torch
.
testing
.
assert_close
(
out_v1
,
out_v2
,
rtol
=
2e-2
,
atol
=
2e-2
)
def
_build_extend_inputs_swa_style
(
B
,
N_CTX
,
H_Q
,
H_KV
,
D
,
device
,
dtype
):
"""Layout aligned with sglang test_triton_attention_kernels sliding-window setup."""
b_seq_len_prefix
=
torch
.
randint
(
1
,
N_CTX
//
2
,
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_seq_len_extend
=
torch
.
randint
(
1
,
N_CTX
//
2
,
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_seq_len
=
b_seq_len_prefix
+
b_seq_len_extend
b_start_loc
=
torch
.
zeros
((
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_start_loc
[
1
:]
=
torch
.
cumsum
(
b_seq_len
[:
-
1
],
0
)
b_start_loc_extend
=
torch
.
zeros
((
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_start_loc_extend
[
1
:]
=
torch
.
cumsum
(
b_seq_len_extend
[:
-
1
],
0
)
kv_indptr
=
torch
.
zeros
((
B
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
kv_indptr
[
1
:
B
+
1
]
=
torch
.
cumsum
(
b_seq_len_prefix
[:
B
],
dim
=
0
)
kv_indices
=
torch
.
zeros
(
(
b_seq_len_prefix
.
sum
().
item
(),),
dtype
=
torch
.
int32
,
device
=
device
)
for
i
in
range
(
B
):
kv_indices
[
kv_indptr
[
i
]
:
kv_indptr
[
i
+
1
]]
=
torch
.
arange
(
b_start_loc
[
i
],
b_start_loc
[
i
]
+
b_seq_len_prefix
[
i
],
device
=
device
)
total_token_num
=
torch
.
sum
(
b_seq_len
).
item
()
extend_token_num
=
torch
.
sum
(
b_seq_len_extend
).
item
()
k_buffer
=
torch
.
empty
(
(
total_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
v_buffer
=
torch
.
empty
(
(
total_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
k_extend
=
torch
.
empty
((
extend_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
)
v_extend
=
torch
.
empty
((
extend_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
)
q_extend
=
torch
.
empty
((
extend_token_num
,
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
)
for
i
in
range
(
B
):
extend_start_in_buffer
=
b_start_loc
[
i
]
+
b_seq_len_prefix
[
i
]
extend_end_in_buffer
=
b_start_loc
[
i
]
+
b_seq_len
[
i
]
extend_start
=
b_start_loc_extend
[
i
]
extend_end
=
b_start_loc_extend
[
i
]
+
b_seq_len_extend
[
i
]
k_extend
[
extend_start
:
extend_end
]
=
k_buffer
[
extend_start_in_buffer
:
extend_end_in_buffer
]
v_extend
[
extend_start
:
extend_end
]
=
v_buffer
[
extend_start_in_buffer
:
extend_end_in_buffer
]
q_extend
[
extend_start
:
extend_end
]
=
torch
.
empty
(
(
b_seq_len_extend
[
i
],
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
b_seq_len_extend
=
b_seq_len
-
b_seq_len_prefix
max_len_extend
=
torch
.
max
(
b_seq_len_extend
,
0
)[
0
].
item
()
qo_indptr
=
torch
.
zeros
((
B
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
qo_indptr
[
1
:
B
+
1
]
=
torch
.
cumsum
(
b_seq_len_extend
[:
B
],
dim
=
0
)
return
(
q_extend
,
k_extend
,
v_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
max_len_extend
,
)
@
pytest
.
mark
.
parametrize
(
"window_size"
,
[
-
1
,
32
,
127
])
def
test_extend_attention_v2_sliding_window
(
window_size
):
"""v2 + sliding_window_size vs torch reference (sglang-style construction)."""
torch
.
manual_seed
(
42
)
device
=
"cuda"
dtype
=
torch
.
bfloat16
B
,
N_CTX
,
H_Q
,
H_KV
,
D
=
4
,
512
,
8
,
8
,
128
(
q_extend
,
k_extend
,
v_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
max_len_extend
,
)
=
_build_extend_inputs_swa_style
(
B
,
N_CTX
,
H_Q
,
H_KV
,
D
,
device
,
dtype
)
extend_token_num
=
q_extend
.
shape
[
0
]
o_triton
=
torch
.
empty
((
extend_token_num
,
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
)
o_torch
=
torch
.
empty
((
extend_token_num
,
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
)
extend_attention_fwd
(
q_extend
,
k_extend
,
v_extend
,
o_triton
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
=
None
,
is_causal
=
True
,
mask_indptr
=
None
,
max_len_extend
=
max_len_extend
,
sm_scale
=
None
,
logit_cap
=
0.0
,
skip_prefix_custom_mask
=
True
,
config
=
None
,
k_scale
=
1.2
,
v_scale
=
1.2
,
sliding_window_size
=
window_size
,
sinks
=
None
,
window_kv_offsets
=
None
,
xai_temperature_len
=-
1
,
)
extend_attention_fwd_torch_swa
(
q_extend
,
k_extend
,
v_extend
,
o_torch
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
window_size
,
k_scale
=
1.2
,
v_scale
=
1.2
,
sm_scale
=
None
,
)
torch
.
testing
.
assert_close
(
o_triton
,
o_torch
,
rtol
=
2e-2
,
atol
=
2e-2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_op_fwd
(
1
,
2
,
1024
,
1024
,
256
,
0
,
256
,
torch
.
bfloat16
,
"normal"
,
False
)
test_op_fwd
(
1
,
2
,
1024
,
1024
,
256
,
0
,
256
,
torch
.
bfloat16
,
"normal"
,
False
)
test_op_fwd
(
3
,
5
,
110
,
333
,
18
,
0
,
17
,
torch
.
float32
,
"normal"
,
True
)
test_op_fwd
(
3
,
5
,
110
,
333
,
18
,
0
,
17
,
torch
.
float32
,
"normal"
,
True
)
...
...
rebuild_aiter.sh
View file @
f233de81
...
@@ -24,6 +24,8 @@ echo "### start rebuild aiter..."
...
@@ -24,6 +24,8 @@ echo "### start rebuild aiter..."
# 3)AITER_LOG_MORE=2: Python/JIT 流程、模块触发、参数和更高层日志
# 3)AITER_LOG_MORE=2: Python/JIT 流程、模块触发、参数和更高层日志
# 4) MAX_JOBS=1: 串行编译
# 4) MAX_JOBS=1: 串行编译
# 5)AITER_REBUILD=1: 强制重新编译所有模块(默认只编译缺失的模块)
# 5)AITER_REBUILD=1: 强制重新编译所有模块(默认只编译缺失的模块)
# 6) AITER_USE_SCM_VERSION=1: 版本号基于 'git describe --tags --long' 输出(默认为0)
# AITER_USE_SCM_VERSION=0: 版本号使用 setup.py 中的 AITER_FIXED_VERSION,编译时可指定的环境变量;
# install mode:
# install mode:
PYTHONUNBUFFERED
=
1
GPU_ARCHS
=
"gfx936;gfx938"
PREBUILD_KERNELS
=
1
AITER_PREBUILD_LOG_PROGRESS
=
1 python setup.py bdist_wheel
PYTHONUNBUFFERED
=
1
GPU_ARCHS
=
"gfx936;gfx938"
PREBUILD_KERNELS
=
1
AITER_PREBUILD_LOG_PROGRESS
=
1 python setup.py bdist_wheel
...
...
setup.py
View file @
f233de81
...
@@ -223,9 +223,13 @@ class ForcePlatlibDistribution(Distribution):
...
@@ -223,9 +223,13 @@ class ForcePlatlibDistribution(Distribution):
return
True
return
True
AITER_FIXED_VERSION
=
os
.
environ
.
get
(
"AITER_FIXED_VERSION"
,
"0.1.2"
)
AITER_USE_SCM_VERSION
=
int
(
os
.
environ
.
get
(
"AITER_USE_SCM_VERSION"
,
0
))
version_kwargs
=
{
"use_scm_version"
:
True
}
if
AITER_USE_SCM_VERSION
else
{
"version"
:
AITER_FIXED_VERSION
}
setup
(
setup
(
name
=
PACKAGE_NAME
,
name
=
PACKAGE_NAME
,
use_scm_version
=
True
,
**
version_kwargs
,
packages
=
[
"aiter_meta"
,
"aiter"
],
packages
=
[
"aiter_meta"
,
"aiter"
],
include_package_data
=
True
,
include_package_data
=
True
,
package_data
=
{
package_data
=
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment