Commit 656944ac authored by yangql's avatar yangql
Browse files

增加triton的indexer的kcahche读写操作

parent 12b5bcb1
......@@ -9,12 +9,14 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_triton, cp_gather_indexer_k_bf16_cache_triton
from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt
......@@ -73,7 +75,8 @@ def sparse_attn_indexer(
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache(
k,
......@@ -82,7 +85,12 @@ def sparse_attn_indexer(
quant_block_size,
scale_fmt,
)
else:
indexer_k_bf16_cache_triton(
k,
kv_cache,
slot_mapping,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
......@@ -90,7 +98,7 @@ def sparse_attn_indexer(
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k.dtype,),
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or get_gcn_arch_name == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
......@@ -112,7 +120,7 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
elif torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
elif get_gcn_arch_name == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
......@@ -135,15 +143,23 @@ def sparse_attn_indexer(
k_scale.view(torch.float32).flatten(),
True
)
else:
else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
cp_gather_indexer_k_bf16_cache_triton(
kv_cache,
k_fp8,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
k_fp8,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
None,
......
......@@ -216,6 +216,256 @@ def cp_gather_indexer_k_quant_cache_triton(
head_tile_size,
)
@triton.jit
def _indexer_k_bf16_cache_kernel(
k_ptr, # [num_tokens, head_dim] (bf16)
kv_cache_ptr, # [n_blks, block_size, head_dim] (bf16)
slot_mapping_ptr, # [num_tokens]
kv_cache_stride, # KV Cache 第一维的stride
block_size: tl.constexpr,
num_tokens: tl.constexpr,
head_dim: tl.constexpr,
LAYOUT: tl.constexpr,
BLOCK_TILE_SIZE: tl.constexpr,
HEAD_TILE_SIZE: tl.constexpr,
):
"""
Triton 核函数:将 BF16 类型的 K 张量写入 KV Cache(
"""
tid = tl.program_id(0)
# 边界检查:超出 token 范围直接返回
if tid >= num_tokens:
return
# 定义头维度索引偏移(覆盖整个 head_dim)
offset = tl.arange(0, head_dim)
# 计算输入 K 张量的源指针偏移
src_ptr = k_ptr + tid * head_dim
# 加载当前 token 对应的 cache slot ID
slot_id = tl.load(slot_mapping_ptr + tid)
# 无效 slot(-1)直接返回
if slot_id < 0:
return
# 计算 block ID 和块内偏移
block_id = slot_id // block_size
block_offset = slot_id % block_size
# 分块相关的偏移计算(兼容 SHUFFLE 布局)
tile_block_id = block_offset // BLOCK_TILE_SIZE
tile_block_offset = block_offset % BLOCK_TILE_SIZE
# 根据布局计算 KV Cache 的目标指针偏移
if LAYOUT == "SHUFFLE":
# SHUFFLE 布局的偏移计算
tile_offset = (
offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
+ offset % HEAD_TILE_SIZE
)
dst_ptr = (
kv_cache_ptr
+ block_id * kv_cache_stride
+ tile_block_id * BLOCK_TILE_SIZE * head_dim
+ tile_block_offset * HEAD_TILE_SIZE
)
else:
# NHD 标准布局
tile_offset = offset
dst_ptr = (
kv_cache_ptr + block_id * kv_cache_stride + block_offset * head_dim
)
val = tl.load(src_ptr + offset)
tl.store(dst_ptr + tile_offset, val)
def indexer_k_bf16_cache_triton(
k: torch.Tensor,
kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim] (bf16)
slot_mapping: torch.Tensor,
block_tile_size=16,
head_tile_size=16,
):
"""
将 BF16 类型的 K 张量写入 BF16 类型的 KV Cache
Args:
k: 输入 K 张量 [num_tokens, head_dim] (bf16)
kv_cache: KV Cache 张量 [num_blocks, block_size, head_dim] (bf16)
slot_mapping: token 到 cache slot 的映射 [num_tokens]
block_tile_size: 块分块大小
head_tile_size: 头维度分块大小
"""
# 输入类型校验
assert k.dtype == torch.bfloat16, "k 必须是 bf16 类型"
assert kv_cache.dtype == torch.bfloat16, "kv_cache 必须是 bf16 类型"
# 解析张量维度
num_blocks = kv_cache.shape[0]
block_size = kv_cache.shape[1]
head_dim = k.shape[-1]
num_tokens = slot_mapping.shape[0]
# 验证维度合法性
assert kv_cache.shape[2] == head_dim, "kv_cache 的 head_dim 必须与 k 一致"
# 重塑 KV Cache 为二维(便于指针计算)
kv_cache_2d = kv_cache.view(num_blocks, -1) # [num_blocks, block_size * head_dim]
# 调整 head_tile_size(兼容原逻辑,按字节数归一化)
head_tile_size = head_tile_size // kv_cache.element_size()
# 配置 Triton 核函数的 grid(每个 token 一个 program)
grid = (num_tokens,)
_indexer_k_bf16_cache_kernel[grid](
k,
kv_cache_2d,
slot_mapping,
kv_cache_2d.stride(0),
block_size,
num_tokens,
head_dim,
"NHD", # 布局类型
block_tile_size,
head_tile_size,
)
@triton.jit
def _cp_gather_indexer_k_bf16_cache_kernel(
kv_cache_ptr, # [num_blocks, block_size * head_dim] (bf16)
k_bf16_ptr, # [num_tokens, head_dim] (bf16)
block_table_ptr,
cu_seq_lens_ptr,
block_size: tl.constexpr,
batch_size: tl.constexpr,
num_blocks_per_seq: tl.constexpr,
kv_cache_stride: tl.constexpr,
head_dim: tl.constexpr,
num_tokens: tl.constexpr,
BLOCK_TILE_SIZE: tl.constexpr,
HEAD_TILE_SIZE: tl.constexpr,
):
"""
Triton 核函数 BF16 K Cache 收集
"""
token_idx = tl.program_id(0)
# 边界检查:超出 token 范围直接返回
if token_idx >= num_tokens:
return
# 定义头维度索引偏移(覆盖整个 head_dim)
head_offset = tl.arange(0, head_dim)
batch_idx = tl.full((), -1, dtype=tl.int32)
# 遍历所有 batch(Triton 支持有限循环,需固定循环次数)
for b in tl.static_range(batch_size):
# 加载当前 batch 的序列起始/结束位置
seq_start = tl.load(cu_seq_lens_ptr + b)
seq_end = tl.load(cu_seq_lens_ptr + b + 1)
# 条件判断:当前 token 是否属于该 batch
is_in_batch = (token_idx >= seq_start) & (token_idx < seq_end)
# 条件赋值:如果属于该 batch,更新 batch_idx(替代 break)
batch_idx = tl.where(is_in_batch, b, batch_idx)
# 无效的 batch ID(token 不在任何序列中),直接返回
if batch_idx == -1:
return
# --------------------------
# 计算序列内偏移和 block 索引
# --------------------------
# token 在所属序列内的相对偏移
seq_start = tl.load(cu_seq_lens_ptr + batch_idx)
inbatch_seq_idx = token_idx - seq_start
# 计算该 token 对应的 block 索引(block_table 中的位置)
block_table_id = inbatch_seq_idx // block_size
# 边界检查:block 索引超出范围则返回
if block_table_id >= num_blocks_per_seq:
return
# 计算 block_table 中的内存偏移并加载 block ID
block_table_offset = batch_idx * num_blocks_per_seq + block_table_id
block_id = tl.load(block_table_ptr + block_table_offset)
# 计算 token 在 block 内的偏移
block_offset = inbatch_seq_idx % block_size
# --------------------------
# 计算内存偏移
# --------------------------
# KV Cache 源偏移:block_id * 块步长 + 块内偏移 * head_dim
src_block_offset = block_id * kv_cache_stride
src_inblock_offset = src_block_offset + block_offset * head_dim
# 输出张量目标偏移
dst_inblock_offset = token_idx * head_dim
src_ptr = kv_cache_ptr + src_inblock_offset + head_offset
val = tl.load(src_ptr)
dst_ptr = k_bf16_ptr + dst_inblock_offset + head_offset
tl.store(dst_ptr, val)
def cp_gather_indexer_k_bf16_cache_triton(
k_cache: torch.Tensor, # [num_blocks, block_size, head_dim] (bf16)
k_bf16: torch.Tensor, # [num_tokens, head_dim] (bf16)
block_table: torch.Tensor, # [batch_size, num_blocks_per_seq]
cu_seq_lens: torch.Tensor, # [batch_size + 1]
block_tile_size: int = 16,
head_tile_size: int = 16,
):
"""
BF16 K Cache 收集算子
Args:
k_cache: K缓存张量 [num_blocks, block_size, head_dim] (bf16)
k_bf16: 输出张量 [num_tokens, head_dim] (bf16)
block_table: 块表 [batch_size, num_blocks_per_seq]
cu_seq_lens: 序列长度累积数组 [batch_size + 1]
block_tile_size: 块分块大小
head_tile_size: 头维度分块大小
"""
# 输入类型校验
assert k_cache.dtype == torch.bfloat16, "k_cache 必须是 bf16 类型"
assert k_bf16.dtype == torch.bfloat16, "k_bf16 必须是 bf16 类型"
# 解析维度参数
num_tokens = k_bf16.size(0)
block_size = k_cache.size(1)
head_dim = k_bf16.shape[-1]
num_blocks = k_cache.shape[0]
batch_size = block_table.size(0)
num_blocks_per_seq = block_table.size(1)
# 重塑缓存张量(便于指针计算)
k_cache_2d = k_cache.view(num_blocks, -1) # [num_blocks, block_size * head_dim]
# 配置 Triton 核函数的 grid(每个 token 一个 program)
grid = (num_tokens,)
_cp_gather_indexer_k_bf16_cache_kernel[grid](
k_cache_2d,
k_bf16,
block_table,
cu_seq_lens,
block_size,
batch_size,
num_blocks_per_seq,
k_cache_2d.stride(0), # kv_cache stride (block维度)
head_dim,
num_tokens,
block_tile_size,
head_tile_size,
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment