Commit 34d497a1 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev_3.21_yql' into 'v0.15.1-dev'

关闭sparse_mla的num_head到64/128的pad,以及添加控制fp8_use_mixed_batch模式的环境变量控制,FP8_USE_MI...

See merge request dcutoolkit/deeplearing/vllm!524
parents 31be48ea ed5b3425
...@@ -297,6 +297,7 @@ if TYPE_CHECKING: ...@@ -297,6 +297,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_USE_CAT_MLA: bool = False VLLM_USE_CAT_MLA: bool = False
FP8_USE_MIXED_BATCH: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
VLLM_USE_PP_BALANCE = True VLLM_USE_PP_BALANCE = True
VLLM_MOE_ROUTER_CAPTURE: bool = False VLLM_MOE_ROUTER_CAPTURE: bool = False
...@@ -1825,7 +1826,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1825,7 +1826,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use fused cat and mla # vllm will use fused cat and mla
"VLLM_USE_CAT_MLA": "VLLM_USE_CAT_MLA":
lambda: (os.getenv('VLLM_USE_CAT_MLA', 'False').lower() in lambda: (os.getenv('VLLM_USE_CAT_MLA', 'False').lower() in
("true", "1")), ("true", "1")),
# vllm will use fused cat and mla
"FP8_USE_MIXED_BATCH":
lambda: (os.getenv('FP8_USE_MIXED_BATCH', 'False').lower() in
("true", "1")),
# vLLM will use opt merge_aatn_states,not triton # vLLM will use opt merge_aatn_states,not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......
...@@ -9,12 +9,14 @@ from vllm.forward_context import get_forward_context ...@@ -9,12 +9,14 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadata,
) )
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_triton, cp_gather_indexer_k_bf16_cache_triton
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt from lightop import op, gemmopt
...@@ -73,13 +75,8 @@ def sparse_attn_indexer( ...@@ -73,13 +75,8 @@ def sparse_attn_indexer(
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
# out-of-bounds reads in the kernel.
num_tokens = slot_mapping.shape[0] num_tokens = slot_mapping.shape[0]
k = k[:num_tokens] k = k[:num_tokens]
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache( ops.indexer_k_quant_and_cache(
k, k,
...@@ -88,7 +85,12 @@ def sparse_attn_indexer( ...@@ -88,7 +85,12 @@ def sparse_attn_indexer(
quant_block_size, quant_block_size,
scale_fmt, scale_fmt,
) )
else:
indexer_k_bf16_cache_triton(
k,
kv_cache,
slot_mapping,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1 topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
...@@ -96,7 +98,7 @@ def sparse_attn_indexer( ...@@ -96,7 +98,7 @@ def sparse_attn_indexer(
# Get the full shared workspace buffers once (will allocate on first use) # Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager() workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k.dtype,), ((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or get_gcn_arch_name() == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8), ((total_seq_lens, 4), torch.uint8),
) )
for chunk in prefill_metadata.chunks: for chunk in prefill_metadata.chunks:
...@@ -118,7 +120,7 @@ def sparse_attn_indexer( ...@@ -118,7 +120,7 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) )
elif torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": elif get_gcn_arch_name() == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache( ops.cp_gather_indexer_k_quant_cache(
...@@ -142,14 +144,22 @@ def sparse_attn_indexer( ...@@ -142,14 +144,22 @@ def sparse_attn_indexer(
True True
) )
else: else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
cp_gather_indexer_k_bf16_cache_triton(
kv_cache,
k_fp8,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = op.mqa_logits( logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_fp8[chunk.token_start:chunk.token_end],
k, k_fp8,
weights[chunk.token_start:chunk.token_end].to(torch.float32), weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0], q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0], k_fp8.shape[0],
q_fp8.shape[1], q_fp8.shape[1],
q_fp8.shape[2], q_fp8.shape[2],
None, None,
......
...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional ...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
...@@ -668,7 +669,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -668,7 +669,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
| FlashMLASparseMetadata.FP8KernelMetadata | FlashMLASparseMetadata.FP8KernelMetadata
| None | None
) = None ) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and envs.FP8_USE_MIXED_BATCH
if self.use_fp8_kv_cache: if self.use_fp8_kv_cache:
if fp8_use_mixed_batch: if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm) fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
...@@ -924,14 +925,14 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -924,14 +925,14 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
padded_num_heads = self.fp8_decode_padded_heads padded_num_heads = self.fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128) # Pad query if needed (kernel only supports h_q = 64 or 128)
#if actual_num_heads < padded_num_heads: # if actual_num_heads < padded_num_heads:
# logger.warning_once( # logger.warning_once(
# f"Padding num_heads from {actual_num_heads} to " # f"Padding num_heads from {actual_num_heads} to "
# f"{padded_num_heads} for FP8 sparse decode kernel" # f"{padded_num_heads} for FP8 sparse decode kernel"
# ) # )
# q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3))) # q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
# q_padded[:, :, :actual_num_heads, :] = q # q_padded[:, :, :actual_num_heads, :] = q
# q = q_padded # q = q_padded
out, lse = flash_mla_with_kvcache( out, lse = flash_mla_with_kvcache(
q=q, q=q,
...@@ -964,15 +965,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -964,15 +965,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# NOTE(Chen): kernel requires num_local_head to be a multiple of # NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell # 64 on hopper and 128 on blackwell
if self.num_heads % self.prefill_padding != 0: # if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0 # assert self.prefill_padding % self.num_heads == 0
logger.warning_once( # logger.warning_once(
f"Padding num_heads from {self.num_heads} to " # f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel" # f"{self.prefill_padding} for BF16 sparse prefill kernel"
) # )
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2])) # q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q # q_padded[:, : self.num_heads, :] = q
q = q_padded # q = q_padded
topk_indices = topk_indices.view(num_tokens, 1, -1) topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_fwd( output = flash_mla_sparse_fwd(
......
...@@ -216,6 +216,256 @@ def cp_gather_indexer_k_quant_cache_triton( ...@@ -216,6 +216,256 @@ def cp_gather_indexer_k_quant_cache_triton(
head_tile_size, head_tile_size,
) )
@triton.jit
def _indexer_k_bf16_cache_kernel(
k_ptr, # [num_tokens, head_dim] (bf16)
kv_cache_ptr, # [n_blks, block_size, head_dim] (bf16)
slot_mapping_ptr, # [num_tokens]
kv_cache_stride, # KV Cache 第一维的stride
block_size: tl.constexpr,
num_tokens: tl.constexpr,
head_dim: tl.constexpr,
LAYOUT: tl.constexpr,
BLOCK_TILE_SIZE: tl.constexpr,
HEAD_TILE_SIZE: tl.constexpr,
):
"""
Triton 核函数:将 BF16 类型的 K 张量写入 KV Cache(
"""
tid = tl.program_id(0)
# 边界检查:超出 token 范围直接返回
if tid >= num_tokens:
return
# 定义头维度索引偏移(覆盖整个 head_dim)
offset = tl.arange(0, head_dim)
# 计算输入 K 张量的源指针偏移
src_ptr = k_ptr + tid * head_dim
# 加载当前 token 对应的 cache slot ID
slot_id = tl.load(slot_mapping_ptr + tid)
# 无效 slot(-1)直接返回
if slot_id < 0:
return
# 计算 block ID 和块内偏移
block_id = slot_id // block_size
block_offset = slot_id % block_size
# 分块相关的偏移计算(兼容 SHUFFLE 布局)
tile_block_id = block_offset // BLOCK_TILE_SIZE
tile_block_offset = block_offset % BLOCK_TILE_SIZE
# 根据布局计算 KV Cache 的目标指针偏移
if LAYOUT == "SHUFFLE":
# SHUFFLE 布局的偏移计算
tile_offset = (
offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
+ offset % HEAD_TILE_SIZE
)
dst_ptr = (
kv_cache_ptr
+ block_id * kv_cache_stride
+ tile_block_id * BLOCK_TILE_SIZE * head_dim
+ tile_block_offset * HEAD_TILE_SIZE
)
else:
# NHD 标准布局
tile_offset = offset
dst_ptr = (
kv_cache_ptr + block_id * kv_cache_stride + block_offset * head_dim
)
val = tl.load(src_ptr + offset)
tl.store(dst_ptr + tile_offset, val)
def indexer_k_bf16_cache_triton(
k: torch.Tensor,
kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim] (bf16)
slot_mapping: torch.Tensor,
block_tile_size=16,
head_tile_size=16,
):
"""
将 BF16 类型的 K 张量写入 BF16 类型的 KV Cache
Args:
k: 输入 K 张量 [num_tokens, head_dim] (bf16)
kv_cache: KV Cache 张量 [num_blocks, block_size, head_dim] (bf16)
slot_mapping: token 到 cache slot 的映射 [num_tokens]
block_tile_size: 块分块大小
head_tile_size: 头维度分块大小
"""
# 输入类型校验
assert k.dtype == torch.bfloat16, "k 必须是 bf16 类型"
assert kv_cache.dtype == torch.bfloat16, "kv_cache 必须是 bf16 类型"
# 解析张量维度
num_blocks = kv_cache.shape[0]
block_size = kv_cache.shape[1]
head_dim = k.shape[-1]
num_tokens = slot_mapping.shape[0]
# 验证维度合法性
assert kv_cache.shape[2] == head_dim, "kv_cache 的 head_dim 必须与 k 一致"
# 重塑 KV Cache 为二维(便于指针计算)
kv_cache_2d = kv_cache.view(num_blocks, -1) # [num_blocks, block_size * head_dim]
# 调整 head_tile_size(兼容原逻辑,按字节数归一化)
head_tile_size = head_tile_size // kv_cache.element_size()
# 配置 Triton 核函数的 grid(每个 token 一个 program)
grid = (num_tokens,)
_indexer_k_bf16_cache_kernel[grid](
k,
kv_cache_2d,
slot_mapping,
kv_cache_2d.stride(0),
block_size,
num_tokens,
head_dim,
"NHD", # 布局类型
block_tile_size,
head_tile_size,
)
@triton.jit
def _cp_gather_indexer_k_bf16_cache_kernel(
kv_cache_ptr, # [num_blocks, block_size * head_dim] (bf16)
k_bf16_ptr, # [num_tokens, head_dim] (bf16)
block_table_ptr,
cu_seq_lens_ptr,
block_size: tl.constexpr,
batch_size: tl.constexpr,
num_blocks_per_seq: tl.constexpr,
kv_cache_stride: tl.constexpr,
head_dim: tl.constexpr,
num_tokens: tl.constexpr,
BLOCK_TILE_SIZE: tl.constexpr,
HEAD_TILE_SIZE: tl.constexpr,
):
"""
Triton 核函数 BF16 K Cache 收集
"""
token_idx = tl.program_id(0)
# 边界检查:超出 token 范围直接返回
if token_idx >= num_tokens:
return
# 定义头维度索引偏移(覆盖整个 head_dim)
head_offset = tl.arange(0, head_dim)
batch_idx = tl.full((), -1, dtype=tl.int32)
# 遍历所有 batch(Triton 支持有限循环,需固定循环次数)
for b in tl.static_range(batch_size):
# 加载当前 batch 的序列起始/结束位置
seq_start = tl.load(cu_seq_lens_ptr + b)
seq_end = tl.load(cu_seq_lens_ptr + b + 1)
# 条件判断:当前 token 是否属于该 batch
is_in_batch = (token_idx >= seq_start) & (token_idx < seq_end)
# 条件赋值:如果属于该 batch,更新 batch_idx(替代 break)
batch_idx = tl.where(is_in_batch, b, batch_idx)
# 无效的 batch ID(token 不在任何序列中),直接返回
if batch_idx == -1:
return
# --------------------------
# 计算序列内偏移和 block 索引
# --------------------------
# token 在所属序列内的相对偏移
seq_start = tl.load(cu_seq_lens_ptr + batch_idx)
inbatch_seq_idx = token_idx - seq_start
# 计算该 token 对应的 block 索引(block_table 中的位置)
block_table_id = inbatch_seq_idx // block_size
# 边界检查:block 索引超出范围则返回
if block_table_id >= num_blocks_per_seq:
return
# 计算 block_table 中的内存偏移并加载 block ID
block_table_offset = batch_idx * num_blocks_per_seq + block_table_id
block_id = tl.load(block_table_ptr + block_table_offset)
# 计算 token 在 block 内的偏移
block_offset = inbatch_seq_idx % block_size
# --------------------------
# 计算内存偏移
# --------------------------
# KV Cache 源偏移:block_id * 块步长 + 块内偏移 * head_dim
src_block_offset = block_id * kv_cache_stride
src_inblock_offset = src_block_offset + block_offset * head_dim
# 输出张量目标偏移
dst_inblock_offset = token_idx * head_dim
src_ptr = kv_cache_ptr + src_inblock_offset + head_offset
val = tl.load(src_ptr)
dst_ptr = k_bf16_ptr + dst_inblock_offset + head_offset
tl.store(dst_ptr, val)
def cp_gather_indexer_k_bf16_cache_triton(
k_cache: torch.Tensor, # [num_blocks, block_size, head_dim] (bf16)
k_bf16: torch.Tensor, # [num_tokens, head_dim] (bf16)
block_table: torch.Tensor, # [batch_size, num_blocks_per_seq]
cu_seq_lens: torch.Tensor, # [batch_size + 1]
block_tile_size: int = 16,
head_tile_size: int = 16,
):
"""
BF16 K Cache 收集算子
Args:
k_cache: K缓存张量 [num_blocks, block_size, head_dim] (bf16)
k_bf16: 输出张量 [num_tokens, head_dim] (bf16)
block_table: 块表 [batch_size, num_blocks_per_seq]
cu_seq_lens: 序列长度累积数组 [batch_size + 1]
block_tile_size: 块分块大小
head_tile_size: 头维度分块大小
"""
# 输入类型校验
assert k_cache.dtype == torch.bfloat16, "k_cache 必须是 bf16 类型"
assert k_bf16.dtype == torch.bfloat16, "k_bf16 必须是 bf16 类型"
# 解析维度参数
num_tokens = k_bf16.size(0)
block_size = k_cache.size(1)
head_dim = k_bf16.shape[-1]
num_blocks = k_cache.shape[0]
batch_size = block_table.size(0)
num_blocks_per_seq = block_table.size(1)
# 重塑缓存张量(便于指针计算)
k_cache_2d = k_cache.view(num_blocks, -1) # [num_blocks, block_size * head_dim]
# 配置 Triton 核函数的 grid(每个 token 一个 program)
grid = (num_tokens,)
_cp_gather_indexer_k_bf16_cache_kernel[grid](
k_cache_2d,
k_bf16,
block_table,
cu_seq_lens,
block_size,
batch_size,
num_blocks_per_seq,
k_cache_2d.stride(0), # kv_cache stride (block维度)
head_dim,
num_tokens,
block_tile_size,
head_tile_size,
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch( def fp8_paged_mqa_logits_torch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment