Unverified commit 1a19e9cd, authored by vllmellm, committed by GitHub
Browse files

[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize...


[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten (#31380)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent c8ed39b9
......@@ -112,6 +112,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
op: Callable,
block_size: int = 32,
) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
......@@ -138,7 +139,6 @@ def test_contexted_kv_attention(
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode
......@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
op: Callable,
block_size: int = 32,
) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
......@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi(
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
......@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32(
test_contexted_kv_attention_alibi(
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_qwen3_nonstandard_block_size(
    head_size: int,
    dtype: torch.dtype,
    device: str,
    op: Callable,
) -> None:
    """Exercise the contexted KV attention test with a block size of 544.

    544 is not a power of two; it is the block size associated with
    Qwen3-Next-80B. This wrapper delegates to
    ``test_contexted_kv_attention`` with a fixed configuration (64 heads,
    one query per KV head, sliding window disabled, ``"auto"`` KV cache
    dtype) and only varies ``head_size``, ``dtype``, ``device`` and ``op``
    through parametrization.

    The test is skipped on every platform other than ROCm.
    """
    on_rocm = current_platform.is_rocm()
    if not on_rocm:
        pytest.skip("544 block size optimization is only for ROCm.")

    # Fixed scenario; every value is passed by keyword to the shared test.
    attention_case = {
        "num_heads": 64,
        "num_queries_per_kv": 1,
        "head_size": head_size,
        "block_size": 544,
        "sliding_window": 0,
        "dtype": dtype,
        "kv_cache_dtype": "auto",
        "device": device,
        "op": op,
    }
    test_contexted_kv_attention(**attention_case)
......@@ -46,6 +46,7 @@ def kernel_paged_attention_2d(
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
PHYSICAL_BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
......@@ -104,14 +105,15 @@ def kernel_paged_attention_2d(
if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
L = tl.zeros([num_queries_per_kv_padded], dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.where(float("-inf") < M, 1.0, 0.0)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
# sequence len for this particular sequence
......@@ -125,30 +127,45 @@ def kernel_paged_attention_2d(
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
offs_n = tl.arange(0, BLOCK_SIZE)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
offs_n = tl.arange(0, BLOCK_SIZE)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
v_offset = (
physical_block_idx * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_d[None, :] * stride_v_cache_2
+ offs_n[:, None] * stride_v_cache_3
)
start_n = j * BLOCK_SIZE
# Calculate the logical location within a non-standard physical block,
# such as 544 in Qwen/Qwen3-Next-80B-A3B-Thinking.
# Supports non-contiguous mapping
# from logical blocks to physical blocks
abs_token_idx = start_n + offs_n
l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE
# Vectorized loading of physical block IDs
p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
# 5D addressing logic of K
k_offset = (
physical_block_idx * stride_k_cache_0
p_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ (offs_d[:, None] // x) * stride_k_cache_2
+ offs_n[None, :] * stride_k_cache_3
+ internal_offsets[None, :] * stride_k_cache_3
+ (offs_d[:, None] % x) * stride_k_cache_4
)
# 4D addressing logic of V (Slot is innermost)
v_offset = (
p_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_d[None, :] * stride_v_cache_2
+ internal_offsets[:, None] * stride_v_cache_3
)
# K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None],
other=0.0,
eviction_policy="evict_last",
)
if K_load.dtype.is_fp8():
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
......@@ -156,7 +173,12 @@ def kernel_paged_attention_2d(
K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :],
other=0.0,
eviction_policy="evict_last",
)
if V_load.dtype.is_fp8():
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
......@@ -167,9 +189,9 @@ def kernel_paged_attention_2d(
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
seq_mask = seq_offset[None, :] < boundary
# S : (num_queries_per_kv, BLOCK_SIZE,)
S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32)
S += scale * tl.dot(Q, K)
# First calculate the dot, then apply the mask.
qk = scale * tl.dot(Q, K)
S = tl.where(head_mask[:, None] & seq_mask, qk, float("-inf"))
context_len = seq_len - 1
......@@ -184,13 +206,15 @@ def kernel_paged_attention_2d(
m_j = tl.maximum(M, tl.max(S, axis=1))
# P : (num_queries_per_kv, BLOCK_SIZE,)
P = tl.exp(S - m_j[:, None])
p = tl.exp(S - m_j[:, None])
p = tl.where(m_j[:, None] == float("-inf"), 0.0, p)
# l_j : (num_queries_per_kv,)
l_j = tl.sum(P, axis=1)
l_j = tl.sum(p, axis=1)
# alpha : (num_queries_per_kv, )
alpha = tl.exp(M - m_j)
alpha = tl.where(float("-inf") == M, 0.0, alpha)
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc = acc * alpha[:, None]
......@@ -200,10 +224,10 @@ def kernel_paged_attention_2d(
M = m_j
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc += tl.dot(P.to(V.dtype), V)
acc += tl.dot(p.to(V.dtype), V)
# epilogue
acc = acc / L[:, None]
acc = acc / (L[:, None] + 1e-10)
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
......@@ -241,9 +265,10 @@ def chunked_prefill_paged_decode(
output_scale=None,
# Optional tensor for sinks
sinks=None,
is_block_table_ptr: bool = False,
):
if sm_scale is None:
sm_scale = 1.0 / (query.shape[1] ** 0.5)
sm_scale = 1.0 / (query.shape[2] ** 0.5)
use_alibi_slopes = alibi_slopes is not None
......@@ -315,6 +340,16 @@ def chunked_prefill_paged_decode(
alibi_slopes,
sinks,
)
# Triton is only forced when encountering a non-standard block
# like Qwen3 with a size of 544.
# 1. Check if block_size is a power of 2 (16, 32, 64...)
# 2. If it's a power of 2, we trust the vLLM's native use_custom decision.
# 3. If it's not a power of 2 (such as Qwen3's 544),
# then our Triton path is forced.
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if not is_pow2:
use_custom = False
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = (
......@@ -356,6 +391,25 @@ def chunked_prefill_paged_decode(
fp8_out_scale=output_scale,
)
else:
real_block_size = value_cache.shape[3]
# The standard model directly uses the original block_size.
# Non-standard 544 uses 32 to accommodate integer division logic.
TRITON_BLOCK_SIZE = block_size if is_pow2 else 32
if is_block_table_ptr:
# Using the physical base address of tensors
kv_element_size = key_cache.element_size()
block_byte_stride = key_cache.stride(0) * kv_element_size
# Get the starting physical address of the KV Cache
base_addr = key_cache.data_ptr()
# Normalization: Directly calculate the block offset
# of the pointer relative to the base address
processed_block_table = ((block_table - base_addr) // block_byte_stride).to(
torch.int32
)
else:
processed_block_table = block_table.to(torch.int32)
kernel_paged_attention_2d[
(
num_seqs,
......@@ -367,7 +421,7 @@ def chunked_prefill_paged_decode(
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
sink_ptr=sinks,
block_tables_ptr=block_table,
block_tables_ptr=processed_block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
......@@ -377,12 +431,13 @@ def chunked_prefill_paged_decode(
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
block_table_stride=processed_block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
BLOCK_SIZE=TRITON_BLOCK_SIZE,
PHYSICAL_BLOCK_SIZE=real_block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
......
......@@ -79,6 +79,7 @@ def _fwd_kernel(
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
PHYSICAL_BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
num_unroll_cache: tl.constexpr,
......@@ -139,42 +140,52 @@ def _fwd_kernel(
# initialize pointer to m and l
if not USE_SINKS:
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
else:
m_i = tl.load(
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
mask=(offs_m < cur_batch_query_len),
other=float("-inf"),
).to(dtype=tl.float32)
l_i = tl.where(m_i > float("-inf"), 1.0, 0.0)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in tl.range(
0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
):
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
# Under a block size of 544 (Qwen/Qwen3-Next-80B-A3B-Thinking),
# replace one physical block every 17 32-Tile blocks
# Calculate the logical block index of each of the 32 tokens
# in the current Tile (handling cross-block cases).
token_indices = start_n + offs_bs_n
bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
# 2. Vectorized loading of physical block IDs from B_Loc
bn = tl.load(
B_Loc
+ cur_batch * stride_b_loc_b
+ (start_n // BLOCK_SIZE) * stride_b_loc_s
B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s
).to(tl.int64)
# [D,BLOCK_SIZE]
# 3. Calculate the exact offset of
# each token within its physical block.
internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
# Addressing of K (5D)
off_k = (
bn[None, :] * stride_k_cache_bs
+ cur_kv_head * stride_k_cache_h
+ (offs_d[:, None] // x) * stride_k_cache_d
+ ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl
+ internal_offsets[None, :] * stride_k_cache_bl
+ (offs_d[:, None] % x) * stride_k_cache_x
)
# [BLOCK_SIZE,D]
# Addressing of V (4D)
off_v = (
bn[:, None] * stride_v_cache_bs
+ cur_kv_head * stride_v_cache_h
+ offs_d[None, :] * stride_v_cache_d
+ offs_bs_n[:, None] * stride_v_cache_bl
+ internal_offsets[:, None] * stride_v_cache_bl
)
if (
......@@ -195,12 +206,12 @@ def _fwd_kernel(
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
# qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = sm_scale * tl.dot(q, k, input_precision=IN_PRECISION)
qk = tl.where(
(start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
)
qk *= sm_scale
# qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
......@@ -217,14 +228,16 @@ def _fwd_kernel(
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
< SLIDING_WINDOW,
qk,
-10000,
float("-inf"),
)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
acc = acc * alpha[:, None]
# update acc
......@@ -293,14 +306,17 @@ def _fwd_kernel(
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk,
-10000,
float("-inf"),
)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
# To prevent NaN from appearing in the first round
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
acc = acc * alpha[:, None]
# update acc
......@@ -317,7 +333,7 @@ def _fwd_kernel(
l_i = l_i * alpha + l_ij
m_i = m_ij
acc = acc / l_i[:, None]
acc = acc / (l_i[:, None] + 1e-10)
# initialize pointers to output
off_o = (
......@@ -637,6 +653,7 @@ def context_attention_fwd(
skip_decode=False,
fp8_out_scale=None,
sinks=None,
is_block_table_ptr: bool = False,
):
q_dtype_is_f32 = q.dtype is torch.float32
......@@ -689,6 +706,19 @@ def context_attention_fwd(
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if is_block_table_ptr:
kv_element_size = k_cache.element_size()
block_byte_stride = k_cache.stride(0) * kv_element_size
# The physical starting point of the obtained KV Cache Pool
base_addr = k_cache.data_ptr()
mask = b_loc > 0
processed_b_loc = torch.where(
mask, (b_loc - base_addr) // block_byte_stride, b_loc
).to(torch.int32)
else:
processed_b_loc = b_loc.to(torch.int32)
if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
......@@ -752,7 +782,24 @@ def context_attention_fwd(
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
extra_kargs = {}
real_block_size = v_cache.shape[3]
is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
# For standard models involving powers of 2,
# follow the original logic (Llama 128/64)
# For non-standard models (Qwen3-next block_size 544), set to 32.
if is_pow2:
BLOCK_M = 128
BLOCK_N = 64
else:
BLOCK_M = 32
BLOCK_N = 32
# TRITON_BLOCK_SIZE is kept at 32 to ensure
# correct alignment logic when the kernel handles
# non-standard sizes (such as 544).
TRITON_BLOCK_SIZE = 32
grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid_fn](
......@@ -762,7 +809,7 @@ def context_attention_fwd(
k_cache,
v_cache,
sinks,
b_loc,
processed_b_loc,
sm_scale,
k_scale,
v_scale,
......@@ -771,8 +818,8 @@ def context_attention_fwd(
b_seq_len,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
processed_b_loc.stride(0),
processed_b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
......@@ -785,16 +832,17 @@ def context_attention_fwd(
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
stride_k_cache_bs=k_cache.stride(0),
stride_k_cache_h=k_cache.stride(1),
stride_k_cache_d=k_cache.stride(2),
stride_k_cache_bl=k_cache.stride(3),
stride_k_cache_x=k_cache.stride(4),
stride_v_cache_bs=v_cache.stride(0),
stride_v_cache_h=v_cache.stride(1),
stride_v_cache_d=v_cache.stride(2),
stride_v_cache_bl=v_cache.stride(3),
BLOCK_SIZE=TRITON_BLOCK_SIZE,
PHYSICAL_BLOCK_SIZE=real_block_size,
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk,
......@@ -802,8 +850,8 @@ def context_attention_fwd(
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None,
BLOCK_M=128,
BLOCK_N=64,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_unroll_cache=4,
num_unroll_request=1,
num_warps=4,
......
......@@ -20,10 +20,15 @@ def reshape_and_cache_kernel_flash(
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
head_stride: tl.int64,
dim_stride_k: tl.int64,
dim_stride_v: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size: tl.constexpr,
block_size: tl.constexpr,
x: tl.constexpr,
USE_HEAD_MAJOR_LAYOUT: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
# tune parameters
......@@ -35,17 +40,38 @@ def reshape_and_cache_kernel_flash(
# Padding token that should be ignored.
return
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride
tgt_idx = block_idx * block_stride + block_offset * page_stride
if USE_HEAD_MAJOR_LAYOUT:
# Decompose the tile index back into head and dim coordinates.
cur_head = tile_pos // head_size
cur_dim = tile_pos % head_size
# Value addressing (4D): [Block, Head, Dim, Slot]
tgt_idx_v = (
block_idx * block_stride
+ cur_head * head_stride
+ cur_dim * dim_stride_v
+ block_offset * 1
)
# Key addressing (5D): [Block, Head, Dim//8, Slot, 8]
tgt_idx_k = (
block_idx * block_stride
+ cur_head * head_stride
+ (cur_dim // x) * dim_stride_k
+ block_offset * x
+ (cur_dim % x)
)
else:
tgt_base = block_idx * block_stride + block_offset * page_stride
tgt_idx_k = tgt_base + tile_pos
tgt_idx_v = tgt_base + tile_pos
# [TILE_SIZE]
key_load = tl.load(
......@@ -73,12 +99,12 @@ def reshape_and_cache_kernel_flash(
value_tile = value_load
tl.store(
key_cache_ptr + tgt_idx + tile_pos,
key_cache_ptr + tgt_idx_k,
key_tile,
mask=tile_pos < (num_heads * head_size),
)
tl.store(
value_cache_ptr + tgt_idx + tile_pos,
value_cache_ptr + tgt_idx_v,
value_tile,
mask=tile_pos < (num_heads * head_size),
)
......@@ -99,17 +125,26 @@ def triton_reshape_and_cache_flash(
):
num_heads = key.shape[1]
head_size = key.shape[2]
block_size = key_cache.shape[1]
n = num_heads * head_size
use_head_major_layout = key_cache.ndim == 5
if use_head_major_layout:
block_size = key_cache.shape[3]
x = key_cache.shape[4]
head_stride = key_cache.stride(1)
dim_stride_k = key_cache.stride(2)
dim_stride_v = value_cache.stride(2)
else:
block_size = key_cache.shape[1]
x = 1
dim_stride_k = 0
dim_stride_v = 0
head_stride = key_cache.stride()[2]
n = num_heads * head_size
key_stride = key.stride()[0]
value_stride = value.stride()[0]
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]
head_stride = key_cache.stride()[2]
assert head_stride == head_size, "only continous heads are supported"
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
)
......@@ -171,10 +206,15 @@ def triton_reshape_and_cache_flash(
key_stride=key_stride,
value_stride=value_stride,
block_stride=block_stride,
head_stride=head_stride,
dim_stride_k=dim_stride_k,
dim_stride_v=dim_stride_v,
page_stride=page_stride,
num_heads=num_heads,
head_size=head_size,
block_size=block_size,
x=x,
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
......
......@@ -15,6 +15,9 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......@@ -321,16 +324,38 @@ class RocmAttentionImpl(AttentionImpl):
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# Determine if it is a power of 2
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if is_pow2:
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
# force using our modified Triton logic
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment