Commit d71496bf authored by zhuwenwen
Browse files

support dsa

parent 1ce0a9a2
......@@ -755,7 +755,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3" || KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
......
......@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
......
......@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
......@@ -73,42 +74,57 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
......@@ -149,15 +165,27 @@ def sparse_attn_indexer(
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
if not current_platform.is_rocm():
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
else:
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len,
)
num_rows = logits.shape[0]
......@@ -258,7 +286,8 @@ class SparseAttnIndexer(CustomOp):
if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights)
# return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_fp8, k, weights)
else:
raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for "
......
......@@ -712,6 +712,8 @@ class Indexer(nn.Module):
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
else:
q_fp8 = q
weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
......
......@@ -261,15 +261,16 @@ class RocmPlatform(Platform):
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
assert block_size == 1, (
"Sparse MLA backend on ROCm only supports block size 1 for now."
)
# if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# raise ValueError(
# "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
# )
# assert block_size == 1, (
# "Sparse MLA backend on ROCm only supports block size 1 for now."
# )
logger.info_once("Using Sparse MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
# return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()
if attn_selector_config.use_mla:
# if attn_selector_config.use_sparse:
......
......@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
exclude_from_block_size_selection = True
@staticmethod
def get_name() -> str:
return "DEEPSEEK_V32_INDEXER"
......@@ -323,15 +324,15 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported():
if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
# if is_deep_gemm_supported():
if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
......
......@@ -31,7 +31,8 @@ else:
if current_platform.is_rocm():
import flash_mla.cuda as flash_mla_cuda
# import flash_mla.cuda as flash_mla_cuda
from flash_mla.flash_mla_interface import flash_mla_cuda
_flashmla_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True
......
......@@ -5537,6 +5537,10 @@ class GPUModelRunner(
ValueError: If no valid block size found
"""
#exclude indexer backend
def _participates_in_block_size_selection(backend: type[AttentionBackend]) -> bool:
return not getattr(backend, "exclude_from_block_size_selection", False)
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
......@@ -5558,7 +5562,12 @@ class GPUModelRunner(
return False
return True
backends = [group.backend for group in attn_groups]
all_backends = [group.backend for group in attn_groups]
backends = [
b for b in all_backends
if _participates_in_block_size_selection(b)
]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment