Commit d71496bf authored by zhuwenwen's avatar zhuwenwen
Browse files

support dsa

parent 1ce0a9a2
...@@ -755,7 +755,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -755,7 +755,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3" || KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
......
...@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
getattr(layer, "_marlin_w16a16_moe_enabled", False) getattr(layer, "_marlin_w16a16_moe_enabled", False)
......
...@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.mla.indexer import ( ...@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.mla.indexer import (
) )
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -73,42 +74,57 @@ def sparse_attn_indexer( ...@@ -73,42 +74,57 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache( if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k, ops.indexer_k_quant_and_cache(
kv_cache, k,
slot_mapping, kv_cache,
quant_block_size, slot_mapping,
scale_fmt, quant_block_size,
) scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1 topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use) # Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager() if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( workspace_manager = current_workspace_manager()
((total_seq_lens, head_dim), fp8_dtype), k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, 4), torch.uint8), ((total_seq_lens, head_dim), fp8_dtype),
) ((total_seq_lens, 4), torch.uint8),
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
) )
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits( logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end], q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()), (k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end], weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) )
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[ topk_indices = topk_indices_buffer[
...@@ -149,15 +165,27 @@ def sparse_attn_indexer( ...@@ -149,15 +165,27 @@ def sparse_attn_indexer(
assert batch_size == decode_metadata.seq_lens.shape[0] assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits( if not current_platform.is_rocm():
padded_q_fp8_decode_tokens, logits = fp8_paged_mqa_logits(
kv_cache, padded_q_fp8_decode_tokens,
weights[:num_padded_tokens], kv_cache,
decode_metadata.seq_lens, weights[:num_padded_tokens],
decode_metadata.block_table, decode_metadata.seq_lens,
decode_metadata.schedule_metadata, decode_metadata.block_table,
max_model_len=max_model_len, decode_metadata.schedule_metadata,
) max_model_len=max_model_len,
)
else:
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len,
)
num_rows = logits.shape[0] num_rows = logits.shape[0]
...@@ -258,7 +286,8 @@ class SparseAttnIndexer(CustomOp): ...@@ -258,7 +286,8 @@ class SparseAttnIndexer(CustomOp):
if current_platform.is_cuda(): if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights) return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights) # return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_fp8, k, weights)
else: else:
raise NotImplementedError( raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for " "SparseAttnIndexer native forward is only implemented for "
......
...@@ -712,6 +712,8 @@ class Indexer(nn.Module): ...@@ -712,6 +712,8 @@ class Indexer(nn.Module):
) )
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1) q_scale = q_scale.view(-1, self.n_head, 1)
else:
q_fp8 = q
weights, _ = self.weights_proj(hidden_states) weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
......
...@@ -261,15 +261,16 @@ class RocmPlatform(Platform): ...@@ -261,15 +261,16 @@ class RocmPlatform(Platform):
kv_cache_dtype = attn_selector_config.kv_cache_dtype kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse: if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): # if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError( # raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." # "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
) # )
assert block_size == 1, ( # assert block_size == 1, (
"Sparse MLA backend on ROCm only supports block size 1 for now." # "Sparse MLA backend on ROCm only supports block size 1 for now."
) # )
logger.info_once("Using Sparse MLA backend.") logger.info_once("Using Sparse MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() # return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()
if attn_selector_config.use_mla: if attn_selector_config.use_mla:
# if attn_selector_config.use_sparse: # if attn_selector_config.use_sparse:
......
...@@ -27,6 +27,7 @@ logger = init_logger(__name__) ...@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend): class DeepseekV32IndexerBackend(AttentionBackend):
exclude_from_block_size_selection = True
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "DEEPSEEK_V32_INDEXER" return "DEEPSEEK_V32_INDEXER"
...@@ -323,15 +324,15 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -323,15 +324,15 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported(): # if is_deep_gemm_supported():
if current_platform.is_rocm(): if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
) )
else: else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
) )
decode_metadata = DeepSeekV32IndexerDecodeMetadata( decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes], seq_lens=common_attn_metadata.seq_lens[:num_decodes],
......
...@@ -31,7 +31,8 @@ else: ...@@ -31,7 +31,8 @@ else:
if current_platform.is_rocm(): if current_platform.is_rocm():
import flash_mla.cuda as flash_mla_cuda # import flash_mla.cuda as flash_mla_cuda
from flash_mla.flash_mla_interface import flash_mla_cuda
_flashmla_C_AVAILABLE = True _flashmla_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True _flashmla_extension_C_AVAILABLE = True
......
...@@ -5537,6 +5537,10 @@ class GPUModelRunner( ...@@ -5537,6 +5537,10 @@ class GPUModelRunner(
ValueError: If no valid block size found ValueError: If no valid block size found
""" """
#exclude indexer backend
def _participates_in_block_size_selection(backend: type[AttentionBackend]) -> bool:
return not getattr(backend, "exclude_from_block_size_selection", False)
def block_size_is_supported( def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int backends: list[type[AttentionBackend]], block_size: int
) -> bool: ) -> bool:
...@@ -5558,7 +5562,12 @@ class GPUModelRunner( ...@@ -5558,7 +5562,12 @@ class GPUModelRunner(
return False return False
return True return True
backends = [group.backend for group in attn_groups] all_backends = [group.backend for group in attn_groups]
backends = [
b for b in all_backends
if _participates_in_block_size_selection(b)
]
# Case 1: if the block_size of kv cache manager is supported by all backends, # Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly # return it directly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment