Commit a55b8f91 authored by zhuwenwen's avatar zhuwenwen
Browse files

replace the fp8_mqa_logits and fp8_paged_mqa_logits interfaces in deepgemm...

replace the fp8_mqa_logits and fp8_paged_mqa_logits interfaces in deepgemm with mqa_logits and paged_mqa_logits from lightop
parent 31021d81
......@@ -71,7 +71,6 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
......@@ -83,6 +82,11 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
if current_platform.is_rocm():
from lightop import op, gemmopt
else:
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
......@@ -623,6 +627,7 @@ def sparse_attn_indexer(
chunk.cu_seq_lens,
chunk.num_reqs,
)
if current_platform.is_rocm():
logits = fp8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
......@@ -630,6 +635,14 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
topk_indices -= chunk.cu_seqlen_ks[:, None]
......@@ -667,6 +680,7 @@ def sparse_attn_indexer(
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
if current_platform.is_rocm():
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
......@@ -676,6 +690,16 @@ def sparse_attn_indexer(
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
else:
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
......
......@@ -14,6 +14,8 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.platforms import current_platform
from lightop import gemmopt
logger = init_logger(__name__)
......@@ -309,6 +311,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if current_platform.is_rocm():
self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms)
else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment