Unverified commit bcdfb2a3, authored by Michael Goin and committed by GitHub
Browse files

[Bugfix] Fix incorrect dispatch for CutlassBlockScaledGroupedGemm and DeepGEMM (#20933)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent ba8c3000
...@@ -488,11 +488,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -488,11 +488,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once("Failed to import DeepGemm kernels.") logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant: elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using " logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels") "DeepGemm kernels")
elif (current_platform.is_cuda() elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)): and current_platform.is_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True self.allow_deep_gemm = True
elif (current_platform.is_cuda()
and is_blackwell_deep_gemm_used()):
logger.info_once("Using DeepGemm SM100 kernels for "
"Fp8MoEMethod.")
self.allow_deep_gemm = True
else: else:
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
...@@ -500,10 +505,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -500,10 +505,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Check for CutlassBlockScaledGroupedGemm support. # Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant: if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using " logger.debug_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels") "CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda() elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)): and current_platform.is_device_capability(100)):
logger.info_once( logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment