Unverified commit 6d4e27ce, authored by Michael Goin, committed by GitHub
Browse files

[Bugfix] Enforce DeepGEMM when using sparse_attn_indexer on CUDA (#34374)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 4c078fa5
...@@ -10,6 +10,7 @@ from vllm.logger import init_logger ...@@ -10,6 +10,7 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadata,
...@@ -277,6 +278,10 @@ class SparseAttnIndexer(CustomOp): ...@@ -277,6 +278,10 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer self.topk_indices_buffer = topk_indices_buffer
if current_platform.is_cuda() and not has_deep_gemm():
raise RuntimeError(
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
)
def forward_native( def forward_native(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment