Unverified Commit 03ee4811 authored by amirkl94's avatar amirkl94 Committed by GitHub
Browse files

Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)

parent 5a87076d
...@@ -77,10 +77,14 @@ class TestData: ...@@ -77,10 +77,14 @@ class TestData:
@staticmethod @staticmethod
def make_moe_tensors_8bit( def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
) -> "TestData": ) -> "TestData":
is_gated = activation != "relu2_no_mul"
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) w13 = torch.randn(
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
# Scale to fp8 # Scale to fp8
...@@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph( def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int, m: int,
n: int, n: int,
k: int, k: int,
e: int, e: int,
topk: int, topk: int,
activation: str,
monkeypatch, monkeypatch,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) td = TestData.make_moe_tensors_8bit(
m, k, n, e, reorder=False, activation=activation
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
...@@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
activation="silu", activation=activation,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
...@@ -253,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -253,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer, td.layer,
topk_weights, topk_weights,
topk_ids, topk_ids,
activation="silu", activation=activation,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
......
...@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None, apply_router_weight_on_input: bool | None,
): ):
assert activation == "silu", ( from flashinfer.fused_moe.core import ActivationType
"Only activation silu is supported in FlashInferExperts"
activation_str_to_value_map = {
"silu": ActivationType.Swiglu, # This is the default
"relu2_no_mul": ActivationType.Relu2,
}
assert activation in activation_str_to_value_map, (
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
) )
# Select quantization metadata based on FP8 format/path # Select quantization metadata based on FP8 format/path
...@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
output=output, output=output,
activation_type=activation_str_to_value_map[activation],
# Informs FlashInfer to use the block-scale decoding path when True # Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
) )
......
...@@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
if ( if (
envs.VLLM_USE_FLASHINFER_MOE_FP8 self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and has_flashinfer_moe() and not self.moe.is_act_and_mul
and self.moe.is_act_and_mul
): ):
self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once(
"Non-gated MoE is not supported for min-latency mode,"
"falling back to high-throughput mode"
)
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
logger.info_once( logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
) )
...@@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
if self.flashinfer_moe_backend is not None: if self.flashinfer_moe_backend is not None:
if self.moe.is_act_and_mul:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
register_moe_scaling_factors(layer)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_moe_scaling_factors(layer)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
...@@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(), g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(), g2_alphas=layer.output2_scales_scalar.squeeze(),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale, a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
a2_gscale=1.0 / layer.w2_input_scale, a2_gscale=layer.w2_input_scale_inv,
per_act_token_quant=False, per_act_token_quant=False,
) )
...@@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize assert activation in ("silu", "relu2_no_mul"), (
assert activation == "silu", ( "Expected activation to be in ('silu', 'relu2_no_mul'),"
f"Expected 'silu' activation but got {activation}" f"but got {activation}"
) )
return flashinfer_cutlass_moe_fp8( return flashinfer_cutlass_moe_fp8(
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment