Commit a55b8f91 authored by zhuwenwen's avatar zhuwenwen
Browse files

replace the fp8_mqa_logits and fp8_paged_mqa_logits interfaces in deepgemm...

replace the fp8_mqa_logits and fp8_paged_mqa_logits interfaces in deepgemm with mqa_logits and paged_mqa_logits from lightop
parent 31021d81
...@@ -71,7 +71,6 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk ...@@ -71,7 +71,6 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend, from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata) DeepseekV32IndexerMetadata)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
...@@ -83,6 +82,11 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -83,6 +82,11 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
if current_platform.is_rocm():
from lightop import op, gemmopt
else:
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -623,7 +627,16 @@ def sparse_attn_indexer( ...@@ -623,7 +627,16 @@ def sparse_attn_indexer(
chunk.cu_seq_lens, chunk.cu_seq_lens,
chunk.num_reqs, chunk.num_reqs,
) )
logits = fp8_mqa_logits( if current_platform.is_rocm():
logits = fp8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale), (k_fp8, k_scale),
weights[chunk.token_start:chunk.token_end], weights[chunk.token_start:chunk.token_end],
...@@ -667,15 +680,26 @@ def sparse_attn_indexer( ...@@ -667,15 +680,26 @@ def sparse_attn_indexer(
next_n = padded_q_fp8_decode_tokens.shape[1] next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0] assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits( if current_platform.is_rocm():
padded_q_fp8_decode_tokens, logits = fp8_paged_mqa_logits(
kv_cache, padded_q_fp8_decode_tokens,
weights[:num_padded_tokens], kv_cache,
decode_metadata.seq_lens, weights[:num_padded_tokens],
decode_metadata.block_table, decode_metadata.seq_lens,
decode_metadata.schedule_metadata, decode_metadata.block_table,
max_model_len=max_model_len, decode_metadata.schedule_metadata,
) max_model_len=max_model_len,
)
else:
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len # padded query len
current_device = padded_q_fp8_decode_tokens.device current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n padded_num_tokens = batch_size * next_n
......
...@@ -14,6 +14,8 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, ...@@ -14,6 +14,8 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.platforms import current_platform
from lightop import gemmopt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -309,8 +311,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -309,8 +311,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( if current_platform.is_rocm():
seq_lens, self.kv_cache_spec.block_size, self.num_sms) self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms)
else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms)
decode_metadata = DeepSeekV32IndexerDecodeMetadata( decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata. block_table=common_attn_metadata.
block_table_tensor[:num_decodes, ...], block_table_tensor[:num_decodes, ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment