Unverified Commit dbebb7f8 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Perf] Reuse workspace for FP8+FP4 Marlin MoE (#20500)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 3053a22b
...@@ -398,7 +398,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -398,7 +398,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map,
workspace=layer.workspace)
# FlashInfer fused experts path # FlashInfer fused experts path
if self.fused_experts is not None: if self.fused_experts is not None:
...@@ -940,7 +941,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -940,7 +941,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map,
workspace=layer.workspace)
assert self.fused_experts_func is not None assert self.fused_experts_func is not None
......
...@@ -1103,7 +1103,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1103,7 +1103,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map,
workspace=layer.workspace)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None assert self.block_quant is None
assert (not renormalize and custom_routing_function is not None) assert (not renormalize and custom_routing_function is not None)
......
...@@ -1474,7 +1474,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1474,7 +1474,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map,
workspace=layer.workspace)
if self.fused_experts is not None: if self.fused_experts is not None:
assert self.allow_flashinfer and \ assert self.allow_flashinfer and \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment