Commit d71496bf authored by zhuwenwen's avatar zhuwenwen
Browse files

support dsa

parent 1ce0a9a2
...@@ -755,7 +755,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -755,7 +755,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3" || KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
......
...@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
getattr(layer, "_marlin_w16a16_moe_enabled", False) getattr(layer, "_marlin_w16a16_moe_enabled", False)
......
...@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.mla.indexer import ( ...@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.mla.indexer import (
) )
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -73,6 +74,7 @@ def sparse_attn_indexer( ...@@ -73,6 +74,7 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache( ops.indexer_k_quant_and_cache(
k, k,
kv_cache, kv_cache,
...@@ -86,6 +88,7 @@ def sparse_attn_indexer( ...@@ -86,6 +88,7 @@ def sparse_attn_indexer(
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use) # Get the full shared workspace buffers once (will allocate on first use)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
workspace_manager = current_workspace_manager() workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype), ((total_seq_lens, head_dim), fp8_dtype),
...@@ -109,6 +112,19 @@ def sparse_attn_indexer( ...@@ -109,6 +112,19 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) )
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[ topk_indices = topk_indices_buffer[
...@@ -149,6 +165,7 @@ def sparse_attn_indexer( ...@@ -149,6 +165,7 @@ def sparse_attn_indexer(
assert batch_size == decode_metadata.seq_lens.shape[0] assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n num_padded_tokens = batch_size * next_n
if not current_platform.is_rocm():
logits = fp8_paged_mqa_logits( logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens, padded_q_fp8_decode_tokens,
kv_cache, kv_cache,
...@@ -158,6 +175,17 @@ def sparse_attn_indexer( ...@@ -158,6 +175,17 @@ def sparse_attn_indexer(
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_model_len=max_model_len, max_model_len=max_model_len,
) )
else:
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len,
)
num_rows = logits.shape[0] num_rows = logits.shape[0]
...@@ -258,7 +286,8 @@ class SparseAttnIndexer(CustomOp): ...@@ -258,7 +286,8 @@ class SparseAttnIndexer(CustomOp):
if current_platform.is_cuda(): if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights) return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights) # return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_fp8, k, weights)
else: else:
raise NotImplementedError( raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for " "SparseAttnIndexer native forward is only implemented for "
......
...@@ -712,6 +712,8 @@ class Indexer(nn.Module): ...@@ -712,6 +712,8 @@ class Indexer(nn.Module):
) )
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1) q_scale = q_scale.view(-1, self.n_head, 1)
else:
q_fp8 = q
weights, _ = self.weights_proj(hidden_states) weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
......
...@@ -261,15 +261,16 @@ class RocmPlatform(Platform): ...@@ -261,15 +261,16 @@ class RocmPlatform(Platform):
kv_cache_dtype = attn_selector_config.kv_cache_dtype kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse: if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): # if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError( # raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." # "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
) # )
assert block_size == 1, ( # assert block_size == 1, (
"Sparse MLA backend on ROCm only supports block size 1 for now." # "Sparse MLA backend on ROCm only supports block size 1 for now."
) # )
logger.info_once("Using Sparse MLA backend.") logger.info_once("Using Sparse MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() # return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()
if attn_selector_config.use_mla: if attn_selector_config.use_mla:
# if attn_selector_config.use_sparse: # if attn_selector_config.use_sparse:
......
...@@ -27,6 +27,7 @@ logger = init_logger(__name__) ...@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend): class DeepseekV32IndexerBackend(AttentionBackend):
exclude_from_block_size_selection = True
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "DEEPSEEK_V32_INDEXER" return "DEEPSEEK_V32_INDEXER"
...@@ -323,7 +324,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -323,7 +324,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported(): # if is_deep_gemm_supported():
if current_platform.is_rocm(): if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
......
...@@ -31,7 +31,8 @@ else: ...@@ -31,7 +31,8 @@ else:
if current_platform.is_rocm(): if current_platform.is_rocm():
import flash_mla.cuda as flash_mla_cuda # import flash_mla.cuda as flash_mla_cuda
from flash_mla.flash_mla_interface import flash_mla_cuda
_flashmla_C_AVAILABLE = True _flashmla_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True _flashmla_extension_C_AVAILABLE = True
......
...@@ -5537,6 +5537,10 @@ class GPUModelRunner( ...@@ -5537,6 +5537,10 @@ class GPUModelRunner(
ValueError: If no valid block size found ValueError: If no valid block size found
""" """
#exclude indexer backend
def _participates_in_block_size_selection(backend: type[AttentionBackend]) -> bool:
return not getattr(backend, "exclude_from_block_size_selection", False)
def block_size_is_supported( def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int backends: list[type[AttentionBackend]], block_size: int
) -> bool: ) -> bool:
...@@ -5558,7 +5562,12 @@ class GPUModelRunner( ...@@ -5558,7 +5562,12 @@ class GPUModelRunner(
return False return False
return True return True
backends = [group.backend for group in attn_groups] all_backends = [group.backend for group in attn_groups]
backends = [
b for b in all_backends
if _participates_in_block_size_selection(b)
]
# Case 1: if the block_size of kv cache manager is supported by all backends, # Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly # return it directly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment