Unverified Commit 1a19e9cd authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize...


[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten (#31380)
Signed-off-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
parent c8ed39b9
...@@ -112,6 +112,7 @@ def test_contexted_kv_attention( ...@@ -112,6 +112,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
op: Callable, op: Callable,
block_size: int = 32,
) -> None: ) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip( pytest.skip(
...@@ -138,7 +139,6 @@ def test_contexted_kv_attention( ...@@ -138,7 +139,6 @@ def test_contexted_kv_attention(
MAX_CTX_LEN = 1024 MAX_CTX_LEN = 1024
BS = 10 BS = 10
cache_size = 640 cache_size = 640
block_size = 32
max_block_per_request = 64 max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode # ensure one sequence in batch is a decode
...@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi( ...@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
op: Callable, op: Callable,
block_size: int = 32,
) -> None: ) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip( pytest.skip(
...@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi( ...@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi(
MAX_CTX_LEN = 1024 MAX_CTX_LEN = 1024
BS = 10 BS = 10
cache_size = 640 cache_size = 640
block_size = 32
max_block_per_request = 64 max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
...@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32( ...@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32(
test_contexted_kv_attention_alibi( test_contexted_kv_attention_alibi(
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
) )
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_qwen3_nonstandard_block_size(
head_size: int,
dtype: torch.dtype,
device: str,
op: Callable,
) -> None:
"""
A separate test function specifically added
for Qwen3-Next-80B (Block Size 544).
"""
if not current_platform.is_rocm():
pytest.skip("544 block size optimization is only for ROCm.")
test_contexted_kv_attention(
num_heads=64,
num_queries_per_kv=1,
head_size=head_size,
block_size=544,
sliding_window=0,
dtype=dtype,
kv_cache_dtype="auto",
device=device,
op=op,
)
...@@ -46,6 +46,7 @@ def kernel_paged_attention_2d( ...@@ -46,6 +46,7 @@ def kernel_paged_attention_2d(
output_stride_0: tl.int64, # int output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int BLOCK_SIZE: tl.constexpr, # int
PHYSICAL_BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
...@@ -104,14 +105,15 @@ def kernel_paged_attention_2d( ...@@ -104,14 +105,15 @@ def kernel_paged_attention_2d(
if not USE_SINKS: if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
L = tl.zeros([num_queries_per_kv_padded], dtype=tl.float32)
else: else:
M = tl.load( M = tl.load(
sink_ptr + query_head_idx, sink_ptr + query_head_idx,
mask=head_mask, mask=head_mask,
other=float("-inf"), other=float("-inf"),
).to(dtype=tl.float32) ).to(dtype=tl.float32)
L = tl.where(float("-inf") < M, 1.0, 0.0)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
# sequence len for this particular sequence # sequence len for this particular sequence
...@@ -125,30 +127,45 @@ def kernel_paged_attention_2d( ...@@ -125,30 +127,45 @@ def kernel_paged_attention_2d(
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
offs_n = tl.arange(0, BLOCK_SIZE) offs_n = tl.arange(0, BLOCK_SIZE)
offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_d = tl.arange(0, HEAD_SIZE_PADDED)
# iterate through tiles
v_offset = ( for j in range(0, num_blocks):
physical_block_idx * stride_v_cache_0 start_n = j * BLOCK_SIZE
+ kv_head_idx * stride_v_cache_1 # Calculate the logical location within a non-standard physical block,
+ offs_d[None, :] * stride_v_cache_2 # such as 544 in Qwen/Qwen3-Next-80B-A3B-Thinking.
+ offs_n[:, None] * stride_v_cache_3 # Supports non-contiguous mapping
) # from logical blocks to physical blocks
abs_token_idx = start_n + offs_n
l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE
# Vectorized loading of physical block IDs
p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
# 5D addressing logic of K
k_offset = ( k_offset = (
physical_block_idx * stride_k_cache_0 p_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1 + kv_head_idx * stride_k_cache_1
+ (offs_d[:, None] // x) * stride_k_cache_2 + (offs_d[:, None] // x) * stride_k_cache_2
+ offs_n[None, :] * stride_k_cache_3 + internal_offsets[None, :] * stride_k_cache_3
+ (offs_d[:, None] % x) * stride_k_cache_4 + (offs_d[:, None] % x) * stride_k_cache_4
) )
# 4D addressing logic of V (Slot is innermost)
v_offset = (
p_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_d[None, :] * stride_v_cache_2
+ internal_offsets[:, None] * stride_v_cache_3
)
# K : (HEAD_SIZE, BLOCK_SIZE) # K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None],
other=0.0,
eviction_policy="evict_last",
)
if K_load.dtype.is_fp8(): if K_load.dtype.is_fp8():
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
...@@ -156,7 +173,12 @@ def kernel_paged_attention_2d( ...@@ -156,7 +173,12 @@ def kernel_paged_attention_2d(
K = K_load K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE) # V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0) V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :],
other=0.0,
eviction_policy="evict_last",
)
if V_load.dtype.is_fp8(): if V_load.dtype.is_fp8():
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
...@@ -167,9 +189,9 @@ def kernel_paged_attention_2d( ...@@ -167,9 +189,9 @@ def kernel_paged_attention_2d(
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
seq_mask = seq_offset[None, :] < boundary seq_mask = seq_offset[None, :] < boundary
# S : (num_queries_per_kv, BLOCK_SIZE,) # First calculate the dot, then apply the mask.
S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32) qk = scale * tl.dot(Q, K)
S += scale * tl.dot(Q, K) S = tl.where(head_mask[:, None] & seq_mask, qk, float("-inf"))
context_len = seq_len - 1 context_len = seq_len - 1
...@@ -184,13 +206,15 @@ def kernel_paged_attention_2d( ...@@ -184,13 +206,15 @@ def kernel_paged_attention_2d(
m_j = tl.maximum(M, tl.max(S, axis=1)) m_j = tl.maximum(M, tl.max(S, axis=1))
# P : (num_queries_per_kv, BLOCK_SIZE,) # P : (num_queries_per_kv, BLOCK_SIZE,)
P = tl.exp(S - m_j[:, None]) p = tl.exp(S - m_j[:, None])
p = tl.where(m_j[:, None] == float("-inf"), 0.0, p)
# l_j : (num_queries_per_kv,) # l_j : (num_queries_per_kv,)
l_j = tl.sum(P, axis=1) l_j = tl.sum(p, axis=1)
# alpha : (num_queries_per_kv, ) # alpha : (num_queries_per_kv, )
alpha = tl.exp(M - m_j) alpha = tl.exp(M - m_j)
alpha = tl.where(float("-inf") == M, 0.0, alpha)
# acc : (num_queries_per_kv, BLOCK_SIZE,) # acc : (num_queries_per_kv, BLOCK_SIZE,)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
...@@ -200,10 +224,10 @@ def kernel_paged_attention_2d( ...@@ -200,10 +224,10 @@ def kernel_paged_attention_2d(
M = m_j M = m_j
# acc : (num_queries_per_kv, BLOCK_SIZE,) # acc : (num_queries_per_kv, BLOCK_SIZE,)
acc += tl.dot(P.to(V.dtype), V) acc += tl.dot(p.to(V.dtype), V)
# epilogue # epilogue
acc = acc / L[:, None] acc = acc / (L[:, None] + 1e-10)
if USE_FP8: if USE_FP8:
acc = acc * tl.load(out_scale_inv) acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX) acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
...@@ -241,9 +265,10 @@ def chunked_prefill_paged_decode( ...@@ -241,9 +265,10 @@ def chunked_prefill_paged_decode(
output_scale=None, output_scale=None,
# Optional tensor for sinks # Optional tensor for sinks
sinks=None, sinks=None,
is_block_table_ptr: bool = False,
): ):
if sm_scale is None: if sm_scale is None:
sm_scale = 1.0 / (query.shape[1] ** 0.5) sm_scale = 1.0 / (query.shape[2] ** 0.5)
use_alibi_slopes = alibi_slopes is not None use_alibi_slopes = alibi_slopes is not None
...@@ -315,6 +340,16 @@ def chunked_prefill_paged_decode( ...@@ -315,6 +340,16 @@ def chunked_prefill_paged_decode(
alibi_slopes, alibi_slopes,
sinks, sinks,
) )
# Triton is only forced when encountering a non-standard block
# like Qwen3 with a size of 544.
# 1. Check if block_size is a power of 2 (16, 32, 64...)
# 2. If it's a power of 2, we trust the vLLM's native use_custom decision.
# 3. If it's not a power of 2 (such as Qwen3's 544),
# then our Triton path is forced.
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if not is_pow2:
use_custom = False
if use_custom: if use_custom:
_PARTITION_SIZE_ROCM = 256 _PARTITION_SIZE_ROCM = 256
max_num_partitions = ( max_num_partitions = (
...@@ -356,6 +391,25 @@ def chunked_prefill_paged_decode( ...@@ -356,6 +391,25 @@ def chunked_prefill_paged_decode(
fp8_out_scale=output_scale, fp8_out_scale=output_scale,
) )
else: else:
real_block_size = value_cache.shape[3]
# The standard model directly uses the original block_size.
# Non-standard 544 uses 32 to accommodate integer division logic.
TRITON_BLOCK_SIZE = block_size if is_pow2 else 32
if is_block_table_ptr:
# Using the physical base address of tensors
kv_element_size = key_cache.element_size()
block_byte_stride = key_cache.stride(0) * kv_element_size
# Get the starting physical address of the KV Cache
base_addr = key_cache.data_ptr()
# Normalization: Directly calculate the block offset
# of the pointer relative to the base address
processed_block_table = ((block_table - base_addr) // block_byte_stride).to(
torch.int32
)
else:
processed_block_table = block_table.to(torch.int32)
kernel_paged_attention_2d[ kernel_paged_attention_2d[
( (
num_seqs, num_seqs,
...@@ -367,7 +421,7 @@ def chunked_prefill_paged_decode( ...@@ -367,7 +421,7 @@ def chunked_prefill_paged_decode(
key_cache_ptr=key_cache, key_cache_ptr=key_cache,
value_cache_ptr=value_cache, value_cache_ptr=value_cache,
sink_ptr=sinks, sink_ptr=sinks,
block_tables_ptr=block_table, block_tables_ptr=processed_block_table,
seq_lens_ptr=seq_lens, seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes, alibi_slopes_ptr=alibi_slopes,
scale=sm_scale, scale=sm_scale,
...@@ -377,12 +431,13 @@ def chunked_prefill_paged_decode( ...@@ -377,12 +431,13 @@ def chunked_prefill_paged_decode(
num_query_heads=num_query_heads, num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv, num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded, num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0), block_table_stride=processed_block_table.stride(0),
query_stride_0=query.stride(0), query_stride_0=query.stride(0),
query_stride_1=query.stride(1), query_stride_1=query.stride(1),
output_stride_0=output.stride(0), output_stride_0=output.stride(0),
output_stride_1=output.stride(1), output_stride_1=output.stride(1),
BLOCK_SIZE=block_size, BLOCK_SIZE=TRITON_BLOCK_SIZE,
PHYSICAL_BLOCK_SIZE=real_block_size,
HEAD_SIZE=head_size, HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes, USE_ALIBI_SLOPES=use_alibi_slopes,
......
...@@ -79,6 +79,7 @@ def _fwd_kernel( ...@@ -79,6 +79,7 @@ def _fwd_kernel(
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_PADDED: tl.constexpr, BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
PHYSICAL_BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr, SLIDING_WINDOW: tl.constexpr,
num_unroll_cache: tl.constexpr, num_unroll_cache: tl.constexpr,
...@@ -139,42 +140,52 @@ def _fwd_kernel( ...@@ -139,42 +140,52 @@ def _fwd_kernel(
# initialize pointer to m and l # initialize pointer to m and l
if not USE_SINKS: if not USE_SINKS:
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
else: else:
m_i = tl.load( m_i = tl.load(
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64), sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
mask=(offs_m < cur_batch_query_len), mask=(offs_m < cur_batch_query_len),
other=float("-inf"), other=float("-inf"),
).to(dtype=tl.float32) ).to(dtype=tl.float32)
l_i = tl.where(m_i > float("-inf"), 1.0, 0.0)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here) # compute query against context (no causal mask here)
for start_n in tl.range( for start_n in tl.range(
0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache 0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
): ):
start_n = tl.multiple_of(start_n, BLOCK_SIZE) # Under a block size of 544 (Qwen/Qwen3-Next-80B-A3B-Thinking),
# -- compute qk ---- # replace one physical block every 17 32-Tile blocks
# Calculate the logical block index of each of the 32 tokens
# in the current Tile (handling cross-block cases).
token_indices = start_n + offs_bs_n
bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
# 2. Vectorized loading of physical block IDs from B_Loc
bn = tl.load( bn = tl.load(
B_Loc B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s
+ cur_batch * stride_b_loc_b
+ (start_n // BLOCK_SIZE) * stride_b_loc_s
).to(tl.int64) ).to(tl.int64)
# [D,BLOCK_SIZE]
# 3. Calculate the exact offset of
# each token within its physical block.
internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
# Addressing of K (5D)
off_k = ( off_k = (
bn[None, :] * stride_k_cache_bs bn[None, :] * stride_k_cache_bs
+ cur_kv_head * stride_k_cache_h + cur_kv_head * stride_k_cache_h
+ (offs_d[:, None] // x) * stride_k_cache_d + (offs_d[:, None] // x) * stride_k_cache_d
+ ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + internal_offsets[None, :] * stride_k_cache_bl
+ (offs_d[:, None] % x) * stride_k_cache_x + (offs_d[:, None] % x) * stride_k_cache_x
) )
# [BLOCK_SIZE,D] # Addressing of V (4D)
off_v = ( off_v = (
bn[:, None] * stride_v_cache_bs bn[:, None] * stride_v_cache_bs
+ cur_kv_head * stride_v_cache_h + cur_kv_head * stride_v_cache_h
+ offs_d[None, :] * stride_v_cache_d + offs_d[None, :] * stride_v_cache_d
+ offs_bs_n[:, None] * stride_v_cache_bl + internal_offsets[:, None] * stride_v_cache_bl
) )
if ( if (
...@@ -195,12 +206,12 @@ def _fwd_kernel( ...@@ -195,12 +206,12 @@ def _fwd_kernel(
else: else:
k = k_load k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] # qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk = sm_scale * tl.dot(q, k, input_precision=IN_PRECISION)
qk = tl.where( qk = tl.where(
(start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") (start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
) )
qk *= sm_scale # qk *= sm_scale
if SLIDING_WINDOW > 0: if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence # Q entries in sequence
...@@ -217,14 +228,16 @@ def _fwd_kernel( ...@@ -217,14 +228,16 @@ def _fwd_kernel(
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :]) (cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
< SLIDING_WINDOW, < SLIDING_WINDOW,
qk, qk,
-10000, float("-inf"),
) )
# compute running maximum # compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None]) p = tl.exp(qk - m_ij[:, None])
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
l_ij = tl.sum(p, axis=1) l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij) alpha = tl.exp(m_i - m_ij)
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
# update acc # update acc
...@@ -293,14 +306,17 @@ def _fwd_kernel( ...@@ -293,14 +306,17 @@ def _fwd_kernel(
qk = tl.where( qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk, qk,
-10000, float("-inf"),
) )
# compute running maximum # compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None]) p = tl.exp(qk - m_ij[:, None])
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
l_ij = tl.sum(p, axis=1) l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij) alpha = tl.exp(m_i - m_ij)
# To prevent NaN from appearing in the first round
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
# update acc # update acc
...@@ -317,7 +333,7 @@ def _fwd_kernel( ...@@ -317,7 +333,7 @@ def _fwd_kernel(
l_i = l_i * alpha + l_ij l_i = l_i * alpha + l_ij
m_i = m_ij m_i = m_ij
acc = acc / l_i[:, None] acc = acc / (l_i[:, None] + 1e-10)
# initialize pointers to output # initialize pointers to output
off_o = ( off_o = (
...@@ -637,6 +653,7 @@ def context_attention_fwd( ...@@ -637,6 +653,7 @@ def context_attention_fwd(
skip_decode=False, skip_decode=False,
fp8_out_scale=None, fp8_out_scale=None,
sinks=None, sinks=None,
is_block_table_ptr: bool = False,
): ):
q_dtype_is_f32 = q.dtype is torch.float32 q_dtype_is_f32 = q.dtype is torch.float32
...@@ -689,6 +706,19 @@ def context_attention_fwd( ...@@ -689,6 +706,19 @@ def context_attention_fwd(
if sliding_window is None or sliding_window <= 0: if sliding_window is None or sliding_window <= 0:
sliding_window = 0 sliding_window = 0
if is_block_table_ptr:
kv_element_size = k_cache.element_size()
block_byte_stride = k_cache.stride(0) * kv_element_size
# The physical starting point of the obtained KV Cache Pool
base_addr = k_cache.data_ptr()
mask = b_loc > 0
processed_b_loc = torch.where(
mask, (b_loc - base_addr) // block_byte_stride, b_loc
).to(torch.int32)
else:
processed_b_loc = b_loc.to(torch.int32)
if alibi_slopes is not None: if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi" assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi" assert fp8_out_scale is None, "FP8 output not supported with alibi"
...@@ -752,7 +782,24 @@ def context_attention_fwd( ...@@ -752,7 +782,24 @@ def context_attention_fwd(
max_seq_len = 0 if max_seq_len is None else max_seq_len max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {} extra_kargs = {}
if current_platform.is_rocm(): if current_platform.is_rocm():
extra_kargs = {"kpack": 1, "waves_per_eu": 2} extra_kargs = {}
real_block_size = v_cache.shape[3]
is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
# For standard models involving powers of 2,
# follow the original logic (Llama 128/64)
# For non-standard models (Qwen3-next block_size 544), set to 32.
if is_pow2:
BLOCK_M = 128
BLOCK_N = 64
else:
BLOCK_M = 32
BLOCK_N = 32
# TRITON_BLOCK_SIZE is kept at 32 to ensure
# correct alignment logic when the kernel handles
# non-standard sizes (such as 544).
TRITON_BLOCK_SIZE = 32
grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid_fn]( _fwd_kernel[grid_fn](
...@@ -762,7 +809,7 @@ def context_attention_fwd( ...@@ -762,7 +809,7 @@ def context_attention_fwd(
k_cache, k_cache,
v_cache, v_cache,
sinks, sinks,
b_loc, processed_b_loc,
sm_scale, sm_scale,
k_scale, k_scale,
v_scale, v_scale,
...@@ -771,8 +818,8 @@ def context_attention_fwd( ...@@ -771,8 +818,8 @@ def context_attention_fwd(
b_seq_len, b_seq_len,
k_cache.shape[4], k_cache.shape[4],
o, o,
b_loc.stride(0), processed_b_loc.stride(0),
b_loc.stride(1), processed_b_loc.stride(1),
q.stride(0), q.stride(0),
q.stride(1), q.stride(1),
q.stride(2), q.stride(2),
...@@ -785,16 +832,17 @@ def context_attention_fwd( ...@@ -785,16 +832,17 @@ def context_attention_fwd(
o.stride(0), o.stride(0),
o.stride(1), o.stride(1),
o.stride(2), o.stride(2),
k_cache.stride(0), stride_k_cache_bs=k_cache.stride(0),
k_cache.stride(1), stride_k_cache_h=k_cache.stride(1),
k_cache.stride(2), stride_k_cache_d=k_cache.stride(2),
k_cache.stride(3), stride_k_cache_bl=k_cache.stride(3),
k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x] stride_k_cache_x=k_cache.stride(4),
v_cache.stride(0), stride_v_cache_bs=v_cache.stride(0),
v_cache.stride(1), stride_v_cache_h=v_cache.stride(1),
v_cache.stride(2), stride_v_cache_d=v_cache.stride(2),
v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size] stride_v_cache_bl=v_cache.stride(3),
BLOCK_SIZE=v_cache.shape[3], BLOCK_SIZE=TRITON_BLOCK_SIZE,
PHYSICAL_BLOCK_SIZE=real_block_size,
num_queries_per_kv=num_queries_per_kv, num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION, IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
...@@ -802,8 +850,8 @@ def context_attention_fwd( ...@@ -802,8 +850,8 @@ def context_attention_fwd(
SLIDING_WINDOW=sliding_window, SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode, SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None, USE_FP8=fp8_out_scale is not None,
BLOCK_M=128, BLOCK_M=BLOCK_M,
BLOCK_N=64, BLOCK_N=BLOCK_N,
num_unroll_cache=4, num_unroll_cache=4,
num_unroll_request=1, num_unroll_request=1,
num_warps=4, num_warps=4,
......
...@@ -20,10 +20,15 @@ def reshape_and_cache_kernel_flash( ...@@ -20,10 +20,15 @@ def reshape_and_cache_kernel_flash(
key_stride: tl.int64, key_stride: tl.int64,
value_stride: tl.int64, value_stride: tl.int64,
block_stride: tl.int64, block_stride: tl.int64,
head_stride: tl.int64,
dim_stride_k: tl.int64,
dim_stride_v: tl.int64,
page_stride: tl.int64, page_stride: tl.int64,
num_heads: tl.constexpr, num_heads: tl.constexpr,
head_size: tl.constexpr, head_size: tl.constexpr,
block_size: tl.constexpr, block_size: tl.constexpr,
x: tl.constexpr,
USE_HEAD_MAJOR_LAYOUT: tl.constexpr,
# FP8 flags # FP8 flags
FP8_KV_CACHE: tl.constexpr, FP8_KV_CACHE: tl.constexpr,
# tune parameters # tune parameters
...@@ -35,17 +40,38 @@ def reshape_and_cache_kernel_flash( ...@@ -35,17 +40,38 @@ def reshape_and_cache_kernel_flash(
# Padding token that should be ignored. # Padding token that should be ignored.
return return
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
block_idx = slot_idx // block_size block_idx = slot_idx // block_size
block_offset = slot_idx % block_size block_offset = slot_idx % block_size
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
src_key_idx = token_idx * key_stride src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride src_value_idx = token_idx * value_stride
tgt_idx = block_idx * block_stride + block_offset * page_stride if USE_HEAD_MAJOR_LAYOUT:
# Decompose the tile index back into head and dim coordinates.
cur_head = tile_pos // head_size
cur_dim = tile_pos % head_size
# Value addressing (4D): [Block, Head, Dim, Slot]
tgt_idx_v = (
block_idx * block_stride
+ cur_head * head_stride
+ cur_dim * dim_stride_v
+ block_offset * 1
)
# Key addressing (5D): [Block, Head, Dim//8, Slot, 8]
tgt_idx_k = (
block_idx * block_stride
+ cur_head * head_stride
+ (cur_dim // x) * dim_stride_k
+ block_offset * x
+ (cur_dim % x)
)
else:
tgt_base = block_idx * block_stride + block_offset * page_stride
tgt_idx_k = tgt_base + tile_pos
tgt_idx_v = tgt_base + tile_pos
# [TILE_SIZE] # [TILE_SIZE]
key_load = tl.load( key_load = tl.load(
...@@ -73,12 +99,12 @@ def reshape_and_cache_kernel_flash( ...@@ -73,12 +99,12 @@ def reshape_and_cache_kernel_flash(
value_tile = value_load value_tile = value_load
tl.store( tl.store(
key_cache_ptr + tgt_idx + tile_pos, key_cache_ptr + tgt_idx_k,
key_tile, key_tile,
mask=tile_pos < (num_heads * head_size), mask=tile_pos < (num_heads * head_size),
) )
tl.store( tl.store(
value_cache_ptr + tgt_idx + tile_pos, value_cache_ptr + tgt_idx_v,
value_tile, value_tile,
mask=tile_pos < (num_heads * head_size), mask=tile_pos < (num_heads * head_size),
) )
...@@ -99,17 +125,26 @@ def triton_reshape_and_cache_flash( ...@@ -99,17 +125,26 @@ def triton_reshape_and_cache_flash(
): ):
num_heads = key.shape[1] num_heads = key.shape[1]
head_size = key.shape[2] head_size = key.shape[2]
use_head_major_layout = key_cache.ndim == 5
if use_head_major_layout:
block_size = key_cache.shape[3]
x = key_cache.shape[4]
head_stride = key_cache.stride(1)
dim_stride_k = key_cache.stride(2)
dim_stride_v = value_cache.stride(2)
else:
block_size = key_cache.shape[1] block_size = key_cache.shape[1]
x = 1
dim_stride_k = 0
dim_stride_v = 0
head_stride = key_cache.stride()[2]
n = num_heads * head_size n = num_heads * head_size
key_stride = key.stride()[0] key_stride = key.stride()[0]
value_stride = value.stride()[0] value_stride = value.stride()[0]
block_stride = key_cache.stride()[0] block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1] page_stride = key_cache.stride()[1]
head_stride = key_cache.stride()[2]
assert head_stride == head_size, "only continous heads are supported"
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
) )
...@@ -171,10 +206,15 @@ def triton_reshape_and_cache_flash( ...@@ -171,10 +206,15 @@ def triton_reshape_and_cache_flash(
key_stride=key_stride, key_stride=key_stride,
value_stride=value_stride, value_stride=value_stride,
block_stride=block_stride, block_stride=block_stride,
head_stride=head_stride,
dim_stride_k=dim_stride_k,
dim_stride_v=dim_stride_v,
page_stride=page_stride, page_stride=page_stride,
num_heads=num_heads, num_heads=num_heads,
head_size=head_size, head_size=head_size,
block_size=block_size, block_size=block_size,
x=x,
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
# FP8 flags # FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE, FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters # autotune parameters
......
...@@ -15,6 +15,9 @@ from vllm.attention.backends.abstract import ( ...@@ -15,6 +15,9 @@ from vllm.attention.backends.abstract import (
) )
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -321,6 +324,15 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -321,6 +324,15 @@ class RocmAttentionImpl(AttentionImpl):
if self.kv_sharing_target_layer_name is None: if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# Determine if it is a power of 2
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if is_pow2:
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
PagedAttention.write_to_paged_cache( PagedAttention.write_to_paged_cache(
key, key,
value, value,
...@@ -331,6 +343,19 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -331,6 +343,19 @@ class RocmAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
else:
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
# force using our modified Triton logic
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment