Unverified commit e584dce5, authored by Wuxun Zhang and committed by GitHub
Browse files

Add XPU MLA Sparse backend for DeepSeek v3.2 (#33230)


Signed-off-by: Zhang, Wuxun <wuxun.zhang@intel.com>
parent 40c0461f
...@@ -214,3 +214,4 @@ configuration. ...@@ -214,3 +214,4 @@ configuration.
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/ref.py#L7
def _merge_two_lse(
lse0: torch.Tensor, lse1: torch.Tensor | None, s_q: int, h_q: int
) -> torch.Tensor:
if lse1 is None:
return lse0
else:
return torch.logsumexp(
torch.stack([lse0.view(s_q, h_q), lse1.broadcast_to(s_q, h_q)], dim=0),
dim=0,
)
# Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/tests/ref.py#L19
def reference_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int,
    topk_length: torch.Tensor | None = None,
    attn_sink: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for sparse MLA prefill attention.

    Args:
        q: queries, unpacked as ``[s_q, h_q, d_qk]``.
        kv: key/value rows, unpacked as ``[s_kv, *, *]``; rows are gathered
            along dim 0 and reshaped to ``[s_q, topk, d_qk]``.
        indices: selected KV rows per query, unpacked as ``[*, *, topk]``;
            dim 1 is squeezed away. Entries ``< 0`` or ``>= s_kv`` are
            treated as invalid.
        sm_scale: multiplier applied to the raw ``q @ k^T`` scores.
        d_v: number of leading KV channels used as values.
        topk_length: optional per-query count of usable index slots; slots at
            or beyond that count are invalidated.
        attn_sink: optional extra log-sum-exp term merged into the softmax
            normaliser used for the output (not into the returned ``lse``).

    Returns:
        - o: [s_q, h_q, dv], cast to ``kv.dtype``
        - o_fp32: [s_q, h_q, dv]
        - max_logits: [s_q, h_q]
        - lse: [s_q, h_q] (``+inf`` for queries without any valid index)
    """
    neg_inf = float("-inf")
    pos_inf = float("+inf")

    s_q, h_q, d_qk = q.shape
    s_kv, _, _ = kv.shape
    _, _, topk = indices.shape

    # Work on a private copy so the caller's index tensor is never modified.
    indices = indices.clone().squeeze(1)
    if topk_length is not None:
        slot_ids = torch.arange(topk, device=topk_length.device).unsqueeze(0)
        beyond_length = slot_ids.broadcast_to(s_q, topk) >= topk_length.unsqueeze(
            1
        )  # [s_q, topk]
        indices[beyond_length] = -1

    invalid_mask = (indices < 0) | (indices >= s_kv)  # [s_q, topk]
    # Point invalid slots at row 0 so the gather is in bounds; their scores
    # are overwritten with -inf below.
    indices[invalid_mask] = 0

    q = q.float()
    gathered_rows = kv.index_select(dim=0, index=indices.flatten())
    gathered_kv = gathered_rows.reshape(s_q, topk, d_qk).float()  # [s_q, topk, d_qk]

    scores = q @ gathered_kv.transpose(1, 2)  # [s_q, h_q, topk]
    scores *= sm_scale
    scores[invalid_mask.unsqueeze(1).broadcast_to(scores.shape)] = neg_inf

    orig_lse = torch.logsumexp(scores, dim=-1)  # [s_q, h_q]
    max_logits = scores.max(dim=-1).values  # [s_q, h_q]

    lse_for_o = _merge_two_lse(orig_lse, attn_sink, s_q, h_q)
    if not torch.is_inference_mode_enabled():
        lse_for_o = lse_for_o.clone()
    # An infinite normaliser makes exp(score - lse) == 0, so the matching
    # output rows come out as zeros.
    lse_for_o[lse_for_o == neg_inf] = pos_inf

    probs = torch.exp(scores - lse_for_o.unsqueeze(-1))
    out = probs @ gathered_kv[..., :d_v]  # [s_q, h_q, dv]

    # Queries that attended to nothing report +inf as their lse.
    lonely_q_mask = orig_lse == neg_inf  # [s_q, h_q]
    orig_lse[lonely_q_mask] = pos_inf

    return (out.to(kv.dtype), out, max_logits, orig_lse)
@pytest.mark.parametrize("device_str", ["xpu"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(
    not torch.xpu.is_available(),
    reason="XPU is required",
)
def test_bf16_triton_sparse_mla(device_str, dtype):
    """Compare the Triton sparse MLA kernel with the PyTorch reference."""
    device = torch.device(device_str)

    # Problem sizes.
    s_q, s_kv = 1, 256
    h_q, h_kv = 64, 1
    d_qk, d_v = 576, 512
    topk = 128

    torch.random.manual_seed(1234)
    q = torch.randn((s_q, h_q, d_qk), dtype=dtype, device=device)
    kv = torch.randn((s_kv, h_kv, d_qk), dtype=dtype, device=device)

    # Every slot starts out invalid (-1); a random prefix is then filled in.
    indices = torch.full((s_q, h_kv, topk), -1, dtype=torch.int32, device=device)
    for token in range(s_q):
        for kv_head in range(h_kv):
            chosen = torch.randperm(max(1, token))[:topk]
            indices[token, kv_head, : len(chosen)] = chosen

    sm_scale = d_qk**-0.5

    out, max_logits, lse = triton_bf16_mla_sparse_interface(
        q, kv, indices, sm_scale, d_v
    )
    assert out.shape == (s_q, h_q, d_v)
    assert max_logits.shape == (s_q, h_q)
    assert lse.shape == (s_q, h_q)

    ref_out, ref_out_fp32, ref_max_logits, ref_lse = reference_mla_sparse_prefill(
        q, kv, indices, sm_scale, d_v
    )
    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
    assert torch.allclose(max_logits, ref_max_logits, atol=1e-3, rtol=1e-3)
    assert torch.allclose(lse, ref_lse, atol=1e-3, rtol=1e-3)
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -157,3 +158,247 @@ class xpu_ops: ...@@ -157,3 +158,247 @@ class xpu_ops:
"get_scheduler_metadata is not implemented for xpu_ops, returning None." "get_scheduler_metadata is not implemented for xpu_ops, returning None."
) )
return None return None
@staticmethod
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
) -> None:
    """Group-quantize ``k`` to FP8 and scatter it into ``kv_cache``.

    Each token row written to the cache is the FP8 bytes of ``k`` followed
    by the raw bytes of its float32 per-group scale(s).

    Args:
        k: keys; flattened to ``[total_tokens, head_dim]``.
        kv_cache: byte cache, addressed as rows of width
            ``kv_cache.shape[-1]`` (``head_dim`` FP8 bytes plus 4 bytes per
            quantization group).
        slot_mapping: destination row per token (flattened before use).
        quant_block_size: number of channels that share one scale.
        scale_fmt: ``"ue8m0"`` rounds scales up to a power of two; any other
            value keeps the raw scale.
    """
    head_dim = k.shape[-1]
    k = k.view(-1, head_dim)  # [total_tokens, head_dim]

    def group_quant_torch(
        x: torch.Tensor,
        group_size: int,
        eps: float = 1e-10,
        dtype: torch.dtype | None = None,
        column_major_scales: bool = False,
        out_q: torch.Tensor | None = None,
        use_ue8m0: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``x`` per group along its last dim; return (x_q, scales)."""
        if use_ue8m0 is None:
            # Default fallback - could import is_deep_gemm_e8m0_used if needed
            use_ue8m0 = False

        if dtype is None:
            dtype = current_platform.fp8_dtype()

        # Validate inputs
        assert x.shape[-1] % group_size == 0, (
            f"Last dimension {x.shape[-1]} must be divisible by "
            f"group_size {group_size}"
        )
        assert x.stride(-1) == 1, "Input tensor groups must be contiguous"

        # Prepare output tensor
        if out_q is None:
            x_q = torch.empty_like(x, dtype=dtype)
        else:
            assert out_q.shape == x.shape
            x_q = out_q

        # Reshape input for group processing
        # Original shape: (..., last_dim)
        # Target shape: (..., num_groups, group_size)
        original_shape = x.shape
        num_groups = original_shape[-1] // group_size
        group_shape = original_shape[:-1] + (num_groups, group_size)
        x_grouped = x.view(group_shape)

        # Per-group absolute maximum, floored at eps. Shape: (..., num_groups)
        abs_max = torch.amax(torch.abs(x_grouped), dim=-1, keepdim=False)
        abs_max = torch.maximum(
            abs_max, torch.tensor(eps, device=x.device, dtype=x.dtype)
        )

        # Compute scales
        FP8_MAX = torch.finfo(dtype).max
        FP8_MIN = torch.finfo(dtype).min
        scale_raw = abs_max / FP8_MAX

        if use_ue8m0:
            # For UE8M0 format, scales must be powers of 2
            scales = torch.pow(2.0, torch.ceil(torch.log2(scale_raw)))
        else:
            scales = scale_raw

        # Expand scales for broadcasting with grouped data
        # Shape: (..., num_groups, 1)
        scales_expanded = scales.unsqueeze(-1)

        # Quantize the grouped data
        x_scaled = x_grouped / scales_expanded
        x_clamped = torch.clamp(x_scaled, FP8_MIN, FP8_MAX)
        x_quantized = x_clamped.to(dtype)

        # Reshape back to original shape
        x_q.copy_(x_quantized.view(original_shape))

        # Prepare scales tensor in requested format
        if column_major_scales:
            # Column-major: (num_groups,) + batch_dims
            # Transpose the scales to put group dimension first
            scales_shape = (num_groups,) + original_shape[:-1]
            x_s = scales.permute(-1, *range(len(original_shape) - 1))
            x_s = x_s.contiguous().view(scales_shape)
        else:
            # Row-major: batch_dims + (num_groups,)
            x_s = scales.contiguous()

        # Ensure scales are float32
        return x_q, x_s.float()

    k_fp8, k_scale = group_quant_torch(
        k,
        group_size=quant_block_size,
        column_major_scales=False,
        use_ue8m0=(scale_fmt == "ue8m0"),
    )

    k_fp8_bytes = k_fp8.view(-1, head_dim).view(torch.uint8)
    # k_scale is float32 [total_tokens, num_groups]; reinterpreting it as
    # bytes gives [total_tokens, 4 * num_groups]. Keeping that shape (instead
    # of forcing 4 columns) keeps one row per token even when head_dim spans
    # several quantization groups, so the concatenation below stays valid.
    scale_bytes = k_scale.view(torch.uint8)
    k = torch.cat(
        [k_fp8_bytes, scale_bytes], dim=-1
    )  # [total_tokens, head_dim + 4 * num_groups]

    slot_mapping = slot_mapping.flatten()
    # kv_cache is addressed as rows of width kv_cache.shape[-1].
    kv_cache.view(-1, kv_cache.shape[-1]).index_copy_(0, slot_mapping, k)
@staticmethod
def cp_gather_indexer_k_quant_cache(
    kv_cache: torch.Tensor,
    dst_k: torch.Tensor,
    dst_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
) -> None:
    """Gather quantized K values and scales from the paged cache.

    Args:
        kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache
        dst_k: [num_tokens, head_dim] - output tensor for K values
        dst_scale: [num_tokens, head_dim / quant_block_size * 4]
                   - output tensor for scale values
        block_table: [batch_size, num_blocks] - block table for indexing
        cu_seq_lens: [batch_size + 1] - cumulative sequence lengths

    K values for a token are read from
    ``block_start + offset_in_block * head_dim`` and scales from
    ``block_start + head_dim + offset_in_block * scale_size``.

    NOTE(review): the scale offset only matches a "scales stored after all
    K values of the block" layout when block_size == 1 - confirm the
    intended cache layout against the code that writes the cache.
    """
    batch_size = block_table.size(0)
    num_tokens = dst_k.size(0)
    head_dim = dst_k.size(1)
    cache_block_size = kv_cache.size(1)
    quant_block_size = head_dim * 4 // dst_scale.size(1)

    # 1-based token ids so that searchsorted (left side) maps a token in
    # (cu_seq_lens[b], cu_seq_lens[b + 1]] to index b + 1.
    token_indices = torch.arange(num_tokens, device=dst_k.device) + 1
    batch_indices = torch.searchsorted(cu_seq_lens, token_indices) - 1
    batch_indices = torch.clamp(batch_indices, 0, batch_size - 1)

    # 0-based position of each token inside its own sequence.
    inbatch_positions = token_indices - cu_seq_lens[batch_indices] - 1

    # Both the logical block and the in-block offset must come from the same
    # 0-based position; deriving the block from the 1-based index would send
    # the last token of every full block to the next block-table entry.
    block_indices_in_table = inbatch_positions // cache_block_size
    inblock_offsets = inbatch_positions % cache_block_size
    physical_block_indices = block_table[batch_indices, block_indices_in_table]

    # Element stride between consecutive cache blocks.
    block_stride = kv_cache.stride(0)
    # Flatten kv_cache for easier indexing
    kv_cache_flat = kv_cache.view(-1)

    # Source offsets of the K values for all tokens (vectorized).
    src_block_offsets = physical_block_indices * block_stride
    src_k_offsets = src_block_offsets + inblock_offsets * head_dim
    k_indices = src_k_offsets.unsqueeze(1) + torch.arange(
        head_dim, device=dst_k.device
    )
    dst_k[:] = kv_cache_flat[k_indices]

    # Source offsets of the scale values (vectorized).
    scale_size = head_dim * 4 // quant_block_size
    src_scale_offsets = src_block_offsets + head_dim + inblock_offsets * scale_size
    scale_indices = src_scale_offsets.unsqueeze(1) + torch.arange(
        scale_size, device=dst_scale.device
    )
    dst_scale[:] = kv_cache_flat[scale_indices]
@staticmethod
def top_k_per_row_prefill(
    logits: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    raw_topk_indices: torch.Tensor,
    num_rows: int,
    stride0: int,
    strdide1: int,
    topk_tokens: int,
) -> None:
    """Write per-row top-k indices, relative to each row's window start.

    For every row of ``logits`` the ``min(topk_tokens, logits.shape[-1])``
    largest entries are selected. Their column indices are shifted by
    ``cu_seqlen_ks[row]``; any shifted index outside
    ``[0, cu_seqlen_ke[row] - cu_seqlen_ks[row])`` is replaced by ``-1``.
    The result is written in place into the top-left corner of
    ``raw_topk_indices``; nothing is returned.

    ``num_rows``, ``stride0`` and ``strdide1`` are accepted but not used by
    this implementation.
    """
    real_topk = min(topk_tokens, logits.shape[-1])
    topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32)

    # Make the indices relative to the start of each row's valid window.
    topk_indices -= cu_seqlen_ks[:, None]
    mask_lo = topk_indices >= 0
    mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
    topk_indices.masked_fill_(~(mask_lo & mask_hi), -1)

    raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
        topk_indices
    )
@staticmethod
def top_k_per_row_decode(
    logits: torch.Tensor,
    next_n: int,
    seq_lens: torch.Tensor,
    raw_topk_indices: torch.Tensor,
    num_rows: int,
    stride0: int,
    stride1: int,
    topk_tokens: int,
) -> None:
    """Write per-row top-k indices for decode into ``raw_topk_indices``.

    Row ``r`` of ``logits`` belongs to request ``r // next_n`` and may only
    see positions ``<= seq_lens[request] - next_n + (r % next_n)``. Columns
    beyond that limit are masked to ``-inf`` before the top-k, and any
    selected index beyond the limit is replaced by ``-1``.

    At most ``min(topk_tokens, logits.shape[-1])`` indices are selected per
    row, so a logits row narrower than ``topk_tokens`` no longer makes
    ``topk`` raise; columns of ``raw_topk_indices`` past that count are left
    as the caller provided them. The result is written in place; nothing is
    returned.

    ``num_rows``, ``stride0`` and ``stride1`` are accepted but not used by
    this implementation.
    """
    device = logits.device
    batch_size = seq_lens.size(0)
    # padded query len
    padded_num_tokens = batch_size * next_n

    positions = (
        torch.arange(logits.shape[-1], device=device)
        .unsqueeze(0)
        .expand(padded_num_tokens, -1)
    )
    token_ids = torch.arange(padded_num_tokens, device=device)
    row_indices = token_ids // next_n
    next_n_offset = token_ids % next_n
    # index_end_pos: [B * N, 1]
    index_end_pos = (seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)

    # mask: [B * N, L]
    mask = positions <= index_end_pos
    logits = logits.masked_fill(~mask, float("-inf"))

    # Never ask for more entries than a row holds.
    real_topk = min(topk_tokens, logits.shape[-1])
    topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32)  # [B * N, K]

    # When the visible context is shorter than K, top-k also returns masked
    # (out-of-range) positions; mark those as invalid.
    topk_indices[topk_indices > index_end_pos] = -1

    raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
        topk_indices
    )
...@@ -135,6 +135,19 @@ def sparse_attn_indexer( ...@@ -135,6 +135,19 @@ def sparse_attn_indexer(
topk_indices = topk_indices_buffer[ topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens chunk.token_start : chunk.token_end, :topk_tokens
] ]
if current_platform.is_xpu():
ops.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else:
torch.ops._C.top_k_per_row_prefill( torch.ops._C.top_k_per_row_prefill(
logits, logits,
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
...@@ -219,6 +232,18 @@ def sparse_attn_indexer( ...@@ -219,6 +232,18 @@ def sparse_attn_indexer(
lengths, lengths,
None, None,
) )
else:
if current_platform.is_xpu():
ops.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else: else:
torch.ops._C.top_k_per_row_decode( torch.ops._C.top_k_per_row_decode(
logits, logits,
...@@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp): ...@@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp):
k: torch.Tensor, k: torch.Tensor,
weights: torch.Tensor, weights: torch.Tensor,
): ):
if current_platform.is_cuda(): if current_platform.is_cuda() or current_platform.is_xpu():
return self.forward_cuda(hidden_states, q_fp8, k, weights) return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights) return self.forward_hip(hidden_states, q_fp8, k, weights)
else: else:
raise NotImplementedError( raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for " "SparseAttnIndexer native forward is only implemented for "
"CUDA and ROCm platform." "CUDA, ROCm and XPU platforms."
) )
def forward_cuda( def forward_cuda(
......
...@@ -61,7 +61,8 @@ class XPUPlatform(Platform): ...@@ -61,7 +61,8 @@ class XPUPlatform(Platform):
dtype = attn_selector_config.dtype dtype = attn_selector_config.dtype
if attn_selector_config.use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.") logger.info_once("Using XPU MLA Sparse backend.")
return AttentionBackendEnum.XPU_MLA_SPARSE.get_path()
if attn_selector_config.use_mla: if attn_selector_config.use_mla:
logger.info_once("Using Triton MLA backend on V1 engine.") logger.info_once("Using Triton MLA backend on V1 engine.")
return AttentionBackendEnum.TRITON_MLA.get_path() return AttentionBackendEnum.TRITON_MLA.get_path()
......
...@@ -17,4 +17,7 @@ else: ...@@ -17,4 +17,7 @@ else:
tl = TritonLanguagePlaceholder() tl = TritonLanguagePlaceholder()
tldevice = TritonLanguagePlaceholder() tldevice = TritonLanguagePlaceholder()
__all__ = ["HAS_TRITON", "triton", "tl", "tldevice"] LOG2E = 1.4426950408889634
LOGE2 = 0.6931471805599453
__all__ = ["HAS_TRITON", "triton", "tl", "tldevice", "LOG2E", "LOGE2"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
class XPUMLASparseBackend(AttentionBackend):
    """Attention backend descriptor for sparse MLA on XPU.

    Only exposes static capabilities and the classes that implement the
    backend (metadata, metadata builder, attention implementation).
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]

    # ---- capability flags -------------------------------------------------

    @classmethod
    def is_mla(cls) -> bool:
        """This backend is an MLA backend."""
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        """This backend implements sparse attention."""
        return True

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Only a head size of 576 is supported."""
        return [576]

    # ---- identity and implementing classes --------------------------------

    @staticmethod
    def get_name() -> str:
        """Name under which the backend is registered."""
        return "XPU_MLA_SPARSE"

    @staticmethod
    def get_metadata_cls() -> type["XPUMLASparseMetadata"]:
        return XPUMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["XPUMLASparseMetadataBuilder"]:
        return XPUMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["XPUMLASparseImpl"]:
        return XPUMLASparseImpl

    # ---- KV cache layout --------------------------------------------------

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Cache shape is ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` and ``cache_dtype_str`` do not affect the shape.
        """
        return (num_blocks, block_size, head_size)
@dataclass
class XPUMLASparseMetadata(AttentionMetadata):
    # Per-step attention metadata consumed by the XPU sparse MLA implementation.
    # XPUMLASparseMetadataBuilder.build() fills every field.

    # --- batch-level sizes ---
    num_reqs: int
    max_query_len: int
    max_seq_len: int
    num_actual_tokens: int  # Number of tokens excluding padding.

    # --- tensors taken over from the common attention metadata ---
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    block_table: torch.Tensor

    # Request index of every token (one entry per actual token).
    req_id_per_token: torch.Tensor

    # --- scalars forwarded to the request-index -> global-index conversion ---
    block_size: int = 1
    topk_tokens: int = 2048
@dataclass
class XPUMLASparseMetadataBuilder(AttentionMetadataBuilder[XPUMLASparseMetadata]):
    """Builds :class:`XPUMLASparseMetadata` from the common attention metadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache config-derived values and allocate the persistent buffers."""
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)

        # Number of tokens the indexer selects per query (from the HF config).
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.topk_tokens_tensor = torch.tensor(
            [self.topk_tokens], device=device, dtype=torch.int32
        )
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len], device=device, dtype=torch.int32
        )

        # Uninitialized 1x1 placeholder block table; build() does not read it.
        self.dummy_block_table = torch.empty(
            (1, 1), dtype=torch.int32, device=self.device
        )

        # Persistent device buffer with one request id per batched token.
        self.req_id_per_token_buffer = torch.empty(
            (max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> XPUMLASparseMetadata:
        """Create the metadata object for the current batch.

        ``common_prefix_len`` and ``fast_build`` are accepted but not used.
        """
        cam = common_attn_metadata
        num_tokens = cam.num_actual_tokens

        # Expand the per-request query lengths into one request id per token.
        query_starts = np.asarray(cam.query_start_loc_cpu, dtype=np.int32)
        query_lens = np.diff(query_starts)
        request_ids = np.arange(query_lens.shape[0], dtype=np.int32)
        req_ids_host = np.repeat(request_ids, query_lens)

        # Zero the whole buffer first so entries past the copied ids are 0.
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_ids_host.shape[0]].copy_(
            torch.from_numpy(req_ids_host), non_blocking=True
        )
        req_ids_device = self.req_id_per_token_buffer[:num_tokens]

        return XPUMLASparseMetadata(
            num_reqs=cam.num_reqs,
            max_query_len=cam.max_query_len,
            max_seq_len=cam.max_seq_len,
            num_actual_tokens=cam.num_actual_tokens,
            query_start_loc=cam.query_start_loc,
            slot_mapping=cam.slot_mapping,
            block_table=cam.block_table_tensor,
            req_id_per_token=req_ids_device,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
        )
class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
    """Sparse MLA attention for XPU, backed by the Triton bf16 kernel."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        """Store the attention hyper-parameters.

        ``indexer`` is required; its ``topk_indices_buffer`` is the source of
        the per-token top-k indices read in :meth:`forward_mqa`.
        ``mla_args`` must contain ``"kv_lora_rank"``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.scale = float(scale)
        self.softmax_scale = scale

        self.kv_lora_rank: int = mla_args["kv_lora_rank"]

        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,  # [sq, heads, d_qk]
        kv_c_and_k_pe_cache: torch.Tensor,  # [blocks, heads, d_qk]
        topk_indices: torch.Tensor,  # [sq, topk]
        attn_metadata: XPUMLASparseMetadata,
    ) -> torch.Tensor:
        """Run the Triton kernel on an unquantized cache.

        The cache is flattened to ``[rows, 1, last_dim]`` and the indices are
        reshaped to ``[num_tokens, 1, topk]`` before the call. Only the
        first ``self.num_heads`` heads of the kernel output are returned.
        """
        token_count = q.shape[0]
        last_dim = kv_c_and_k_pe_cache.shape[-1]
        flat_cache = kv_c_and_k_pe_cache.view(-1, 1, last_dim)
        per_token_indices = topk_indices.view(token_count, 1, -1)

        attn_output, _, _ = triton_bf16_mla_sparse_interface(
            q,
            flat_cache,
            per_token_indices,
            sm_scale=self.softmax_scale,
        )
        return attn_output[:, : self.num_heads, :]

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: XPUMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sparse MQA forward pass; returns ``(attention_output, None)``.

        Raises:
            NotImplementedError: if the KV cache dtype starts with ``"fp8"``.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet")

        # A (ql_nope, q_pe) pair is joined along the feature dimension.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        assert self.topk_indices_buffer is not None
        token_count = q.shape[0]
        request_local_topk = self.topk_indices_buffer[:token_count]

        # Convert the per-request indices from the indexer to global indices
        # using the block table.
        global_topk = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            request_local_topk,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        attn_out = self._forward_bf16_kv(
            q, kv_c_and_k_pe_cache, global_topk, attn_metadata
        )
        return attn_out, None
...@@ -57,6 +57,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -57,6 +57,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_MLA_SPARSE = ( ROCM_AITER_MLA_SPARSE = (
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend" "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
) )
XPU_MLA_SPARSE = "vllm.v1.attention.backends.mla.xpu_mla_sparse.XPUMLASparseBackend"
TORCH_SDPA = "" # this tag is only used for ViT TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = ( FLASHINFER_MLA = (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import LOG2E, LOGE2, tl, triton
@triton.jit
def _bf16_mla_sparse_kernel(
    q_buffer,
    k_buffer,
    v_buffer,
    indices_ptr,
    out_ptr,
    softmax_lse_ptr,
    max_logits_ptr,
    seq_q,
    seq_kv,
    h_q,
    dim_qk,
    dim_v,
    stride_q_token,
    stride_q_head,
    stride_k_token,
    stride_k_head,
    stride_v_token,
    stride_v_head,
    stride_out_token,
    stride_out_head,
    stride_lse,
    stride_indices_token,
    stride_indices_head,
    sm_scale,
    kv_group_num: tl.constexpr,
    index_topk: tl.constexpr,
    BLOCK_H: tl.constexpr,  # block size for num heads
    BLOCK_M: tl.constexpr,  # block size for num tokens
    BLOCK_N: tl.constexpr,  # block size for indices
    BLOCK_DV: tl.constexpr,  # block size for dim_v
    BLOCK_DMODEL: tl.constexpr,  # block size for dim_nope
    BLOCK_DPE: tl.constexpr,  # block size for positional embedding
    LOGE2: tl.constexpr,
):
    """Sparse MLA attention over the KV rows selected by ``indices_ptr``.

    One program handles one query token (program_id 0) and one block of up
    to BLOCK_H query heads (program_id 1). It walks the token's index list
    in chunks of BLOCK_N, gathers the referenced K/V rows and accumulates an
    online softmax in base 2, so ``sm_scale`` is expected to already include
    the log2(e) factor. Outputs: attention output, log-sum-exp and max
    logit, the latter two converted back to natural log via ``LOGE2``.

    All element offsets below add the innermost index without a stride, i.e.
    the last dimension of q/k/v/indices/out is addressed with stride 1.
    ``seq_q`` and ``BLOCK_M`` are accepted but not used.
    """
    cur_q = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head_id = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

    # Heads handled by this program; lanes beyond VALID_BLOCK_H or h_q are
    # masked out.
    VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < h_q)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)

    # Load the first BLOCK_DMODEL channels of q.
    off_q = cur_q * stride_q_token + cur_head[:, None] * stride_q_head + offs_d[None, :]
    mask_dmodel = offs_d < BLOCK_DMODEL
    q = tl.load(
        q_buffer + off_q, mask=(mask_h[:, None]) & (mask_dmodel[None, :]), other=0.0
    )

    if BLOCK_DPE > 0:
        # Load the trailing BLOCK_DPE channels of q.
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        off_qpe = (
            cur_q * stride_q_token
            + cur_head[:, None] * stride_q_head
            + offs_dpe[None, :]
        )
        # assume dim_qk == BLOCK_DMODEL + BLOCK_DPE
        mask_dpe = offs_dpe < dim_qk
        qpe = tl.load(
            q_buffer + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
        )

    # Online-softmax state: running max, running sum, output accumulator.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    for start_indice in range(0, index_topk, BLOCK_N):
        offs_indice = start_indice + tl.arange(0, BLOCK_N)
        mask_indice = offs_indice < index_topk
        indices = tl.load(
            indices_ptr
            + (
                cur_q * stride_indices_token
                + cur_kv_head_id * stride_indices_head
                + offs_indice
            ),
            mask=mask_indice,
            other=-1,
        )
        # Negative or out-of-range indices mark unused slots.
        mask_kv = (indices >= 0) & (indices < seq_kv)
        mask_kv_d = mask_dmodel

        offs_k = (
            indices[None, :] * stride_k_token
            + cur_kv_head_id * stride_k_head
            + offs_d[:, None]
        )
        # q_nope @ k_nope
        k = tl.load(
            k_buffer + offs_k, mask=(mask_kv[None, :]) & (mask_kv_d[:, None]), other=0.0
        )
        qk = tl.dot(q, k.to(q.dtype))

        if BLOCK_DPE > 0:
            # q_rope @ k_rope
            offs_kpe = (
                indices[None, :] * stride_k_token
                + cur_kv_head_id * stride_k_head
                + offs_dpe[:, None]
            )
            mask_k_dpe = offs_dpe < dim_qk
            kpe = tl.load(
                k_buffer + offs_kpe,
                mask=(mask_kv[None, :]) & (mask_k_dpe[:, None]),
                other=0.0,
            )
            qk += tl.dot(qpe, kpe.to(q.dtype))

        # apply scaling
        qk *= sm_scale
        qk = tl.where((mask_h[:, None]) & (mask_kv[None, :]), qk, -float("inf"))

        # load v
        mask_v_d = offs_dv < dim_v
        offs_v = (
            indices[:, None] * stride_v_token
            + cur_kv_head_id * stride_v_head
            + offs_dv[None, :]
        )
        v = tl.load(
            v_buffer + offs_v, mask=(mask_kv[:, None]) & (mask_v_d[None, :]), other=0.0
        )

        # online softmax
        # NOTE(review): while every index seen so far is invalid, e_max and
        # n_e_max are both -inf and (-inf) - (-inf) is NaN - confirm callers
        # always place at least one valid index in the first chunk.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp2(e_max - n_e_max)
        p = tl.exp2(qk - n_e_max[:, None])
        acc *= re_scale[:, None]

        # score @ v
        acc += tl.dot(p.to(v.dtype), v)

        # update global sum and max
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max

    # rescaling
    acc /= e_sum[:, None]

    # Convert the base-2 statistics back to natural log.
    max_logits = e_max * LOGE2
    # calculate lse
    lse = max_logits + tl.log2(e_sum) * LOGE2

    # write output
    offs_o = (
        cur_q * stride_out_token
        + cur_head[:, None] * stride_out_head
        + offs_dv[None, :]
    )
    mask_out_d = offs_dv < dim_v
    # Cast straight to the output buffer's element type so non-bf16 outputs
    # (e.g. fp16) are not rounded through bfloat16 first.
    tl.store(
        out_ptr + offs_o,
        acc.to(out_ptr.dtype.element_ty),
        mask=(mask_h[:, None]) & (mask_out_d[None, :]),
    )

    offs_lse = cur_q * stride_lse + cur_head
    tl.store(softmax_lse_ptr + offs_lse, lse, mask=mask_h)
    tl.store(max_logits_ptr + offs_lse, max_logits, mask=mask_h)
# reference implementation of bf16 sparse prefill kernel
def triton_bf16_mla_sparse_interface(
    q: torch.Tensor,  # [num_tokens, num_heads_q, dim_qk]
    kv: torch.Tensor,  # [num_tokens, num_heads_kv, dim_qk]
    indices: torch.Tensor,  # [num_tokens, num_heads_kv, topk]
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the Triton sparse MLA kernel.

    Supported configuration (enforced by assertions): ``d_v == 512``,
    ``dim_qk == 576``, a single KV head, and a top-k count that is a
    multiple of 16.

    Returns:
        out        : [num_tokens, num_heads_q, d_v], same dtype as ``q``
        max_logits : [num_tokens, num_heads_q], float32
        lse        : logsumexp, [num_tokens, num_heads_q], float32
    """
    num_tokens, num_heads_q, dim_qk = q.shape
    _, num_heads_kv, _ = kv.shape
    assert dim_qk == kv.shape[2], "q and kv have different head dimensions"

    # for deepseek v3.2, index topk should be 2048
    _, _, index_topk = indices.shape

    # Compile-time tile sizes handed to the kernel.
    BLOCK_H = 16
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_DMODEL = 512
    BLOCK_DPE = 64
    BLOCK_DV = 512

    assert d_v == BLOCK_DV, "only support d_v = 512"
    assert dim_qk == BLOCK_DMODEL + BLOCK_DPE, (
        "dim_qk does not match BLOCK_DMODEL + BLOCK_DPE"
    )
    assert num_heads_kv == 1, "only support kv head = 1 for now"
    assert index_topk % BLOCK_N == 0, "index_topk must be multiple of BLOCK_N"

    # The kernel evaluates the softmax with exp2, so fold log2(e) into the scale.
    sm_scale *= LOG2E

    kv_group_num = num_heads_q // num_heads_kv
    heads_per_program = min(BLOCK_H, kv_group_num)
    grid = (
        num_tokens,
        triton.cdiv(num_heads_q, heads_per_program),
    )

    device = q.device
    stats_shape = (num_tokens, num_heads_q)
    out = torch.zeros((num_tokens, num_heads_q, d_v), dtype=q.dtype, device=device)
    softmax_lse = torch.zeros(stats_shape, dtype=torch.float32, device=device)
    max_logits = torch.zeros(stats_shape, dtype=torch.float32, device=device)

    # Keys use every channel; values are the first d_v channels of the same rows.
    k = kv
    v = kv[..., :d_v]

    _bf16_mla_sparse_kernel[grid](
        q_buffer=q,
        k_buffer=k,
        v_buffer=v,
        indices_ptr=indices,
        out_ptr=out,
        softmax_lse_ptr=softmax_lse,
        max_logits_ptr=max_logits,
        seq_q=num_tokens,
        seq_kv=kv.shape[0],
        h_q=num_heads_q,
        dim_qk=dim_qk,
        dim_v=d_v,
        stride_q_token=q.stride(0),
        stride_q_head=q.stride(1),
        stride_k_token=k.stride(0),
        stride_k_head=k.stride(1),
        stride_v_token=v.stride(0),
        stride_v_head=v.stride(1),
        stride_out_token=out.stride(0),
        stride_out_head=out.stride(1),
        stride_lse=softmax_lse.stride(0),
        stride_indices_token=indices.stride(0),
        stride_indices_head=indices.stride(1),
        sm_scale=sm_scale,
        kv_group_num=kv_group_num,
        index_topk=index_topk,
        BLOCK_H=BLOCK_H,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DV=BLOCK_DV,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        LOGE2=LOGE2,
    )
    return out, max_logits, softmax_lse
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment