Unverified Commit 42135d68 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

parent e14467be
......@@ -128,11 +128,15 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
"""
Return True if the input module/layer could be processed with DeepGEMM.
"""
# FIXME: this logic is brittle and incorrect - since we
# could use DeepGEMM with more than just Fp8LinearMethod
block_size = get_mk_alignment_for_contiguous_layout()[0]
if not (
isinstance(module, LinearBase)
and isinstance(module.quant_method, Fp8LinearMethod)
and module.quant_method.block_quant
and not module.quant_method.use_marlin
):
return False
......
......@@ -29,7 +29,7 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
......@@ -1184,7 +1184,7 @@ class FlashInferImpl(AttentionImpl):
return (
self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
)
# FlashInfer requires attention sinks to be float32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment