Commit f233de81 authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

[SYNC] Code sync.

parent 1893a1e0
# SPDX-License-Identifier: MIT
import torch
import torch.nn.functional as F
import pytest
from aiter.ops.triton.extend_attention import extend_attention_fwd
def extend_attention_fwd_torch_swa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
qo_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
kv_indices: torch.Tensor,
sliding_window_size: int,
*,
k_scale: float = 1.0,
v_scale: float = 1.0,
sm_scale: float | None = None,
):
"""Reference for causal + sliding-window extend attention (sglang test style).
Runs the heavy matmul/softmax on CPU float32 for numerical stability and to avoid
ROCm aborts on large bf16 einsum after GPU kernels.
v2 与 Triton 一致:``k_scale`` / ``v_scale`` **只作用在 prefix(cache)键位**;
extend 段 logits 与 V 不额外乘这两个标量。
"""
B = qo_indptr.size(0) - 1
_, H_Q, D = q.shape
_, H_KV, _ = k.shape
group_size = H_Q // H_KV
scale = float(sm_scale) if sm_scale is not None else 1.0 / D**0.5
out_dev = o.device
out_dtype = o.dtype
for i in range(B):
q_start = int(qo_indptr[i].item())
q_end = int(qo_indptr[i + 1].item())
kv_start = int(kv_indptr[i].item())
kv_end = int(kv_indptr[i + 1].item())
prefix_indices = kv_indices[kv_start:kv_end]
k_prefix = k_cache[prefix_indices]
v_prefix = v_cache[prefix_indices]
k_extend = k[q_start:q_end]
v_extend = v[q_start:q_end]
q_extend = q[q_start:q_end]
k_full = torch.cat([k_prefix, k_extend], dim=0)
v_full = torch.cat([v_prefix, v_extend], dim=0)
if group_size != 1:
k_full_hq = k_full.repeat_interleave(group_size, dim=1)
v_full_hq = v_full.repeat_interleave(group_size, dim=1)
else:
k_full_hq = k_full
v_full_hq = v_full
prefix_len = k_prefix.size(0)
extend_len = k_extend.size(0)
total_len = prefix_len + extend_len
q_e = q_extend.detach().float().cpu()
k_h = k_full_hq.detach().float().cpu()
v_h = v_full_hq.detach().float().cpu()
pos_keys = torch.arange(total_len)
t = prefix_len + torch.arange(extend_len)
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
if sliding_window_size is not None and sliding_window_size > 0:
start = (t - sliding_window_size).clamp_min(0)
else:
start = torch.zeros_like(t)
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
final_mask = causal_mask & window_mask
attn_scores = torch.einsum("qhd,khd->qhk", q_e, k_h) * scale
if k_scale != 1.0:
attn_scores[:, :, :prefix_len] = attn_scores[:, :, :prefix_len] * k_scale
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
if v_scale != 1.0:
v_prefix = v_h[:prefix_len] * v_scale
v_h_scaled = torch.cat([v_prefix, v_h[prefix_len:]], dim=0)
else:
v_h_scaled = v_h
out_cpu = torch.einsum("qhk,khd->qhd", attn_weights, v_h_scaled)
o[q_start:q_end] = out_cpu.to(device=out_dev, dtype=out_dtype)
def input_helper(
B,
H,
......@@ -269,6 +363,218 @@ def test_op_fwd(
torch.testing.assert_close(ref_out, tri_out, rtol=2e-2, atol=2e-2)
def test_extend_attention_v2_identity_scales_match_v1():
    """Check that the v2 call with identity scales reproduces the v1 call.

    The kernel is invoked twice on the same inputs: once without the v2-only
    keyword arguments (the v1 call, where k_scale/v_scale are left unset), and
    once with ``k_scale=1.0``, ``v_scale=1.0``, the sliding window disabled
    (``-1``), no sinks, no window KV offsets and ``xai_temperature_len=-1``.
    Both outputs must agree within ``rtol=atol=2e-2``.
    """
    torch.manual_seed(0)
    device, dtype = "cuda", torch.float16

    generated = input_helper(2, 4, 64, 32, 128, 64, 128, dtype, device, "normal")
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
        max_len_extend,
    ) = generated

    out_shape = (*q_extend.shape[:-1], v_extend.shape[-1])
    out_v1 = torch.empty(out_shape, dtype=q_extend.dtype, device=device)
    out_v2 = torch.empty_like(out_v1)

    # Keyword arguments common to both invocations.
    shared_kwargs = {
        "sm_scale": None,
        "logit_cap": 0.0,
        "skip_prefix_custom_mask": True,
        "config": None,
    }
    # Extra keyword arguments that only the v2 invocation passes.
    v2_only_kwargs = {
        "k_scale": 1.0,
        "v_scale": 1.0,
        "sliding_window_size": -1,
        "sinks": None,
        "window_kv_offsets": None,
        "xai_temperature_len": -1,
    }

    for out, extra_kwargs in ((out_v1, {}), (out_v2, v2_only_kwargs)):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            out,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            True,
            mask_indptr,
            max_len_extend,
            **shared_kwargs,
            **extra_kwargs,
        )

    torch.testing.assert_close(out_v1, out_v2, rtol=2e-2, atol=2e-2)
def _build_extend_inputs_swa_style(B, N_CTX, H_Q, H_KV, D, device, dtype):
    """Build random extend-attention inputs laid out like the sglang SWA test.

    Layout aligned with sglang test_triton_attention_kernels sliding-window
    setup. Each of the ``B`` sequences gets a random prefix length and a random
    extend length, both drawn from ``[1, N_CTX // 2)``. ``k_buffer`` /
    ``v_buffer`` hold every token of every sequence back to back (prefix rows
    first, then extend rows); the extend rows are additionally copied out into
    ``k_extend`` / ``v_extend``. Queries are drawn independently per sequence.

    Args:
        B: Number of sequences.
        N_CTX: Upper bound used for the random lengths (``N_CTX // 2``).
        H_Q: Number of query heads.
        H_KV: Number of key/value heads.
        D: Head dimension.
        device: Device for every created tensor.
        dtype: Dtype of the q/k/v tensors (index tensors are int32).

    Returns:
        Tuple ``(q_extend, k_extend, v_extend, k_buffer, v_buffer, qo_indptr,
        kv_indptr, kv_indices, max_len_extend)``.
    """
    index_opts = {"dtype": torch.int32, "device": device}
    data_opts = {"dtype": dtype, "device": device}

    prefix_lens = torch.randint(1, N_CTX // 2, (B,), **index_opts)
    extend_lens = torch.randint(1, N_CTX // 2, (B,), **index_opts)
    seq_lens = prefix_lens + extend_lens

    # Start of every sequence inside the full buffers and the extend tensors.
    seq_starts = torch.zeros((B,), **index_opts)
    seq_starts[1:] = torch.cumsum(seq_lens[:-1], 0)
    extend_starts = torch.zeros((B,), **index_opts)
    extend_starts[1:] = torch.cumsum(extend_lens[:-1], 0)

    kv_indptr = torch.zeros((B + 1,), **index_opts)
    kv_indptr[1 : B + 1] = torch.cumsum(prefix_lens[:B], dim=0)
    kv_indices = torch.zeros((prefix_lens.sum().item(),), **index_opts)

    # Host-side copies so the loops below index with plain Python ints.
    seq_start_list = seq_starts.tolist()
    prefix_len_list = prefix_lens.tolist()
    seq_len_list = seq_lens.tolist()
    extend_start_list = extend_starts.tolist()
    extend_len_list = extend_lens.tolist()
    kv_bounds = kv_indptr.tolist()

    # Prefix tokens of sequence i occupy the first rows of its buffer span.
    for i, (seq_start, n_prefix) in enumerate(zip(seq_start_list, prefix_len_list)):
        kv_indices[kv_bounds[i] : kv_bounds[i + 1]] = torch.arange(
            seq_start, seq_start + n_prefix, device=device
        )

    total_token_num = torch.sum(seq_lens).item()
    extend_token_num = torch.sum(extend_lens).item()

    k_buffer = torch.empty((total_token_num, H_KV, D), **data_opts).normal_(
        mean=0.1, std=0.2
    )
    v_buffer = torch.empty((total_token_num, H_KV, D), **data_opts).normal_(
        mean=0.1, std=0.2
    )
    k_extend = torch.empty((extend_token_num, H_KV, D), **data_opts)
    v_extend = torch.empty((extend_token_num, H_KV, D), **data_opts)
    q_extend = torch.empty((extend_token_num, H_Q, D), **data_opts)

    per_sequence = zip(
        seq_start_list,
        prefix_len_list,
        seq_len_list,
        extend_start_list,
        extend_len_list,
    )
    for seq_start, n_prefix, n_total, ext_start, n_extend in per_sequence:
        in_buffer = slice(seq_start + n_prefix, seq_start + n_total)
        in_extend = slice(ext_start, ext_start + n_extend)
        k_extend[in_extend] = k_buffer[in_buffer]
        v_extend[in_extend] = v_buffer[in_buffer]
        # Queries are sampled one sequence at a time.
        q_extend[in_extend] = torch.empty((n_extend, H_Q, D), **data_opts).normal_(
            mean=0.1, std=0.2
        )

    max_len_extend = torch.max(extend_lens, 0)[0].item()
    qo_indptr = torch.zeros((B + 1,), **index_opts)
    qo_indptr[1 : B + 1] = torch.cumsum(extend_lens[:B], dim=0)

    return (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        max_len_extend,
    )
@pytest.mark.parametrize("window_size", [-1, 32, 127])
def test_extend_attention_v2_sliding_window(window_size):
    """Compare the v2 kernel with a sliding window against the torch reference.

    Inputs are built sglang-style by ``_build_extend_inputs_swa_style``. The
    kernel and ``extend_attention_fwd_torch_swa`` both receive
    ``k_scale = v_scale = 1.2`` and the parametrized ``window_size``; their
    outputs must agree within ``rtol=atol=2e-2``.
    """
    torch.manual_seed(42)
    device, dtype = "cuda", torch.bfloat16
    B, N_CTX, H_Q, H_KV, D = 4, 512, 8, 8, 128
    kv_scale = 1.2

    built = _build_extend_inputs_swa_style(B, N_CTX, H_Q, H_KV, D, device, dtype)
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        max_len_extend,
    ) = built

    out_shape = (q_extend.shape[0], H_Q, D)
    o_triton = torch.empty(out_shape, dtype=dtype, device=device)
    o_torch = torch.empty(out_shape, dtype=dtype, device=device)

    # Kernel under test.
    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask=None,
        is_causal=True,
        mask_indptr=None,
        max_len_extend=max_len_extend,
        sm_scale=None,
        logit_cap=0.0,
        skip_prefix_custom_mask=True,
        config=None,
        k_scale=kv_scale,
        v_scale=kv_scale,
        sliding_window_size=window_size,
        sinks=None,
        window_kv_offsets=None,
        xai_temperature_len=-1,
    )

    # Torch reference, writing into its own output buffer.
    extend_attention_fwd_torch_swa(
        q_extend,
        k_extend,
        v_extend,
        o_torch,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        window_size,
        k_scale=kv_scale,
        v_scale=kv_scale,
        sm_scale=None,
    )

    torch.testing.assert_close(o_triton, o_torch, rtol=2e-2, atol=2e-2)
if __name__ == "__main__":
test_op_fwd(1, 2, 1024, 1024, 256, 0, 256, torch.bfloat16, "normal", False)
test_op_fwd(3, 5, 110, 333, 18, 0, 17, torch.float32, "normal", True)
......
......@@ -24,6 +24,8 @@ echo "### start rebuild aiter..."
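# 3) AITER_LOG_MORE=2: log the Python/JIT flow, module triggers, arguments, and other higher-level messages
# 4) MAX_JOBS=1: compile serially
# 5) AITER_REBUILD=1: force a rebuild of all modules (by default only missing modules are built)
# 6) AITER_USE_SCM_VERSION=1: derive the version from the output of 'git describe --tags --long' (default is 0)
#    AITER_USE_SCM_VERSION=0: use AITER_FIXED_VERSION from setup.py as the version; it can be overridden at build time via the environment variable of the same name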
# install mode:
PYTHONUNBUFFERED=1 GPU_ARCHS="gfx936;gfx938" PREBUILD_KERNELS=1 AITER_PREBUILD_LOG_PROGRESS=1 python setup.py bdist_wheel
......
......@@ -223,9 +223,13 @@ class ForcePlatlibDistribution(Distribution):
return True
# Fixed fallback version; overridable through the AITER_FIXED_VERSION env var.
AITER_FIXED_VERSION = os.environ.get("AITER_FIXED_VERSION", "0.1.2")
# Non-zero selects an SCM-derived version instead of the fixed one (default 0).
AITER_USE_SCM_VERSION = int(os.environ.get("AITER_USE_SCM_VERSION", 0))
# Keyword arguments carrying the chosen versioning mode.
if AITER_USE_SCM_VERSION:
    version_kwargs = {"use_scm_version": True}
else:
    version_kwargs = {"version": AITER_FIXED_VERSION}
setup(
name=PACKAGE_NAME,
use_scm_version=True,
**version_kwargs,
packages=["aiter_meta", "aiter"],
include_package_data=True,
package_data={
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment