Unverified Commit e5f599d4 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Disable shared expert overlap if Marlin MoE is used (#28410)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 28534b92
...@@ -678,6 +678,10 @@ class FusedMoE(CustomOp): ...@@ -678,6 +678,10 @@ class FusedMoE(CustomOp):
and self.moe_config.use_flashinfer_cutlass_kernels and self.moe_config.use_flashinfer_cutlass_kernels
) )
@property
def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False)
@property @property
def use_dp_chunking(self) -> bool: def use_dp_chunking(self) -> bool:
return ( return (
......
...@@ -28,17 +28,17 @@ class SharedFusedMoE(FusedMoE): ...@@ -28,17 +28,17 @@ class SharedFusedMoE(FusedMoE):
super().__init__(**kwargs) super().__init__(**kwargs)
self._shared_experts = shared_experts self._shared_experts = shared_experts
# Disable shared expert overlap if we are using eplb, because of # Disable shared expert overlap if:
# correctness issues, or if using flashinfer with DP, since there # - we are using eplb, because of correctness issues
# is nothing to be gained in this case. Disabling the overlap # - we are using flashinfer with DP, since there nothint to gain
# optimization also prevents the shared experts from being hidden # - we are using marlin kjernels
# from torch.compile.
self.use_overlapped = ( self.use_overlapped = (
use_overlapped use_overlapped
and not ( and not (
# TODO(wentao): find the root cause and remove this condition # TODO(wentao): find the root cause and remove this condition
self.enable_eplb self.enable_eplb
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
or self.use_marlin_kernels
) )
and self._shared_experts is not None and self._shared_experts is not None
) )
......
...@@ -424,6 +424,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -424,6 +424,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
if self.quant_config.weight_bits != 4: if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.") raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4 self.quant_type = scalar_types.uint4
self.use_marlin = True
def create_weights( def create_weights(
self, self,
......
...@@ -1342,6 +1342,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1342,6 +1342,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
f"{WNA16_SUPPORTED_BITS}", f"{WNA16_SUPPORTED_BITS}",
) )
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
self.use_marlin = True
def create_weights( def create_weights(
self, self,
......
...@@ -482,6 +482,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -482,6 +482,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_type = scalar_types.uint8b128 self.quant_type = scalar_types.uint8b128
else: else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.use_marlin = True
def create_weights( def create_weights(
self, self,
......
...@@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = ( self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size get_current_vllm_config().compilation_config.max_cudagraph_capture_size
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment