Unverified Commit 6c20e89c authored by Pleaplusone's avatar Pleaplusone Committed by GitHub
Browse files

[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp (#29287)


Signed-off-by: default avatarganyi <ygan@amd.com>
parent 85f55c94
...@@ -9,6 +9,10 @@ from torch._ops import OpOverload ...@@ -9,6 +9,10 @@ from torch._ops import OpOverload
import vllm.envs as envs import vllm.envs as envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer,
rocm_aiter_sparse_attn_indexer_fake,
)
_FP8_DTYPE = current_platform.fp8_dtype() _FP8_DTYPE = current_platform.fp8_dtype()
...@@ -1091,6 +1095,14 @@ class rocm_aiter_ops: ...@@ -1091,6 +1095,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="rocm_aiter_sparse_attn_indexer",
op_func=rocm_aiter_sparse_attn_indexer,
mutates_args=["topk_indices_buffer"],
fake_impl=rocm_aiter_sparse_attn_indexer_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True _OPS_REGISTERED = True
@staticmethod @staticmethod
......
...@@ -611,6 +611,7 @@ class CompilationConfig: ...@@ -611,6 +611,7 @@ class CompilationConfig:
"vllm::gdn_attention_core", "vllm::gdn_attention_core",
"vllm::kda_attention", "vllm::kda_attention",
"vllm::sparse_attn_indexer", "vllm::sparse_attn_indexer",
"vllm::rocm_aiter_sparse_attn_indexer",
] ]
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
k,
weights,
quant_block_size,
scale_fmt,
topk_tokens,
head_dim,
max_model_len,
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
decode_metadata = attn_metadata.decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens
)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:]
)
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens,
)
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices
)
return topk_indices_buffer
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
return topk_indices_buffer
direct_register_custom_op(
op_name="sparse_attn_indexer",
op_func=sparse_attn_indexer,
mutates_args=["topk_indices_buffer"],
fake_impl=sparse_attn_indexer_fake,
dispatch_key=current_platform.dispatch_key,
)
@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
"""Sparse Attention Indexer Custom Op Layer. This layer is extracted as a
separate custom op since it involves heavy custom kernels like `mqa_logits`,
`paged_mqa_logits` and `top_k_per_row`, etc. Those kernels maybe requires
specific memory layout or implementation for different hardware backends to
achieve optimal performance.
For now, the default native path will use CUDA backend path. Other platform
may requires add the corresponding Custom Op name `sparse_attn_indexer` to
`custom_ops` in `CompilationConfig` to enable the platform specific path.
"""
def __init__(
self,
k_cache,
quant_block_size: int,
scale_fmt: str,
topk_tokens: int,
head_dim: int,
max_model_len: int,
max_total_seq_len: int,
topk_indices_buffer: torch.Tensor,
):
super().__init__()
self.k_cache = k_cache
self.quant_block_size = quant_block_size
self.scale_fmt = scale_fmt
self.topk_tokens = topk_tokens
self.head_dim = head_dim
self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer
def forward_native(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights)
else:
raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for "
"CUDA and ROCm platform."
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
def forward_hip(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
else:
raise RuntimeError(
"Sparse attention indexer ROCm custom op requires ROCm "
"Aiter ops to be enabled."
)
...@@ -43,7 +43,6 @@ from vllm.distributed import ( ...@@ -43,7 +43,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -63,6 +62,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -63,6 +62,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -74,16 +74,11 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -74,16 +74,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend, DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata,
) )
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
...@@ -94,11 +89,6 @@ from .utils import ( ...@@ -94,11 +89,6 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -599,213 +589,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): ...@@ -599,213 +589,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
return DeepseekV32IndexerBackend return DeepseekV32IndexerBackend
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
k,
weights,
quant_block_size,
scale_fmt,
topk_tokens,
head_dim,
max_model_len,
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
fp8_mqa_logits_func = fp8_mqa_logits
if current_platform.is_rocm():
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_fp8_mqa_logits,
)
fp8_mqa_logits_func = rocm_fp8_mqa_logits
logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
decode_metadata = attn_metadata.decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
# [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens
)
# [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens)
# [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
padded_weights = padded_weights.flatten(0, 1)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:]
)
padded_weights = weights
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
if current_platform.is_rocm():
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_fp8_paged_mqa_logits,
)
fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens,
kv_cache,
padded_weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens,
)
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices
)
return topk_indices_buffer
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
return topk_indices_buffer
direct_register_custom_op(
op_name="sparse_attn_indexer",
op_func=sparse_attn_indexer,
mutates_args=["topk_indices_buffer"],
fake_impl=sparse_attn_indexer_fake,
dispatch_key=current_platform.dispatch_key,
)
class Indexer(nn.Module): class Indexer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -870,6 +653,16 @@ class Indexer(nn.Module): ...@@ -870,6 +653,16 @@ class Indexer(nn.Module):
from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size
self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
self.indexer_op = SparseAttnIndexer(
self.k_cache,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
def forward( def forward(
self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
...@@ -892,6 +685,8 @@ class Indexer(nn.Module): ...@@ -892,6 +685,8 @@ class Indexer(nn.Module):
q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
k_pe = k_pe.reshape(-1, 1, self.rope_dim) k_pe = k_pe.reshape(-1, 1, self.rope_dim)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q = torch.cat([q_pe, q_nope], dim=-1) q = torch.cat([q_pe, q_nope], dim=-1)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA). # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
...@@ -913,21 +708,7 @@ class Indexer(nn.Module): ...@@ -913,21 +708,7 @@ class Indexer(nn.Module):
) )
weights = weights.squeeze(-1) weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer( return self.indexer_op(hidden_states, q_fp8, k, weights)
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
class DeepseekV2MLAAttention(nn.Module): class DeepseekV2MLAAttention(nn.Module):
......
...@@ -480,6 +480,9 @@ class RocmPlatform(Platform): ...@@ -480,6 +480,9 @@ class RocmPlatform(Platform):
): ):
compilation_config.custom_ops.append("+grouped_topk") compilation_config.custom_ops.append("+grouped_topk")
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer")
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:
......
...@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: ...@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks: torch.Tensor cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor cu_seq_lens: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int total_seq_lens: int
token_start: int token_start: int
token_end: int token_end: int
...@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
token_start = query_start_loc_cpu[reqs_start].item() token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item() token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
).to(self.device)
assert total_seq_lens <= self.max_prefill_buffer_size assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = ( cu_seq_lens = (
torch.cat( torch.cat(
...@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke, cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens, cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens, total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end], block_table=block_table[reqs_start:reqs_end],
token_start=token_start, token_start=token_start,
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl, MLACommonBaseImpl,
get_mla_dims, get_mla_dims,
) )
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -33,6 +34,48 @@ if TYPE_CHECKING: ...@@ -33,6 +34,48 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
@triton.jit
def fetch_id_to_ragged_kernel(
in_tensor_ptr, # [num_seq, topk]
cumsum_ptr, # [num_seq + 1]
out_tensor_ptr, # [max_num_seq * topk]
in_tensor_ptr_stride,
TOPK: tl.constexpr,
TOKEN_NUM: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
seq_id = tl.program_id(0)
block_id = tl.program_id(1)
offset = tl.arange(0, BLOCK_SIZE)
token_start = tl.load(cumsum_ptr + seq_id)
token_end = tl.load(cumsum_ptr + seq_id + 1)
token_num = token_end - token_start
row_offset = block_id * BLOCK_SIZE
if row_offset >= token_num:
return
in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset
in_tensor_mask = (row_offset + offset) < TOPK
in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask)
out_tensor_offset = token_start + row_offset + offset
out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask
tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask)
def fetch_id_to_ragged_triton(
in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
):
num_tokens = in_tensor.size(0)
block_size = 64
num_block_per_row = triton.cdiv(topk, block_size)
grid = (
num_tokens,
num_block_per_row,
)
fetch_id_to_ragged_kernel[grid](
in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size
)
class ROCMAiterMLASparseBackend(AttentionBackend): class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
...@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): ...@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
block_table: torch.Tensor block_table: torch.Tensor
req_id_per_token: torch.Tensor req_id_per_token: torch.Tensor
qo_indptr: torch.Tensor
paged_kv_last_page_len: torch.Tensor
paged_kv_indices: torch.Tensor
paged_kv_indptr: torch.Tensor
paged_kv_indptr_rest: torch.Tensor
block_size: int = 1 block_size: int = 1
topk_tokens: int = 2048 topk_tokens: int = 2048
...@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): ...@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
class ROCMAiterMLASparseMetadataBuilder( class ROCMAiterMLASparseMetadataBuilder(
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
): ):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
def __init__( def __init__(
self, self,
...@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder( ...@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder(
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.device = device self.device = device
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config) self.mla_dims = get_mla_dims(self.model_config)
...@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder( ...@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder(
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.qo_indptr = torch.arange(
0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
)
self.paged_kv_last_page_len = torch.ones(
max_num_batched_tokens, dtype=torch.int32, device=device
)
# These two needs to be calculated in runtime,
# but we still needs to prepare the buffer
self.paged_kv_indices = torch.zeros(
[max_num_batched_tokens * self.topk_tokens],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros(
[max_num_batched_tokens + 1], dtype=torch.int32, device=device
)
def build( def build(
self, self,
...@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder( ...@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder(
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True torch.from_numpy(req_id_per_token), non_blocking=True
) )
self.paged_kv_indices.fill_(0)
self.paged_kv_indptr.fill_(0)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens] req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
qo_indptr = self.qo_indptr[: num_tokens + 1]
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]
metadata = ROCMAiterMLASparseMetadata( metadata = ROCMAiterMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
...@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder( ...@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder(
req_id_per_token=req_id_per_token, req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size, block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens, topk_tokens=self.topk_tokens,
qo_indptr=qo_indptr,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indices=paged_kv_indices,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indptr_rest=paged_kv_indptr_rest,
) )
return metadata return metadata
...@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
def _forward_bf16_kv( def _forward_bf16_kv(
self, self,
q: torch.Tensor, q: torch.Tensor, # [sq, heads, d_qk]
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk]
topk_indices: torch.Tensor, topk_indices: torch.Tensor, # [sq, topk]
attn_metadata: ROCMAiterMLASparseMetadata, attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = q.shape[0] num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( output = torch.empty(
-1, 1, kv_c_and_k_pe_cache.shape[-1] [num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
seq_len = (topk_indices != -1).sum(dim=-1)
torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
fetch_id_to_ragged_triton(
topk_indices,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.topk_tokens,
)
rocm_aiter_ops.mla_decode_fwd(
q,
kv_c_and_k_pe_cache,
output,
self.scale,
attn_metadata.qo_indptr,
1,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len,
) )
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = reference_mla_sparse_prefill(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
)[0]
return output[:, : self.num_heads, :] return output[:, : self.num_heads, :]
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment