"vllm/vscode:/vscode.git/clone" did not exist on "11fcf0e0661365f24bfff9591434a0cec640df6c"
Unverified Commit 5dcd7ef1 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

parent ffc0a279
...@@ -11,12 +11,17 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,12 +11,17 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, apply_fi_trtllm_fp8_per_tensor_moe,
flashinfer_cutlass_moe_fp8,
register_scales_for_trtllm_fp8_per_tensor_moe, register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights, rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
swap_w13_to_w31, swap_w13_to_w31,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
...@@ -103,6 +108,7 @@ class TestData: ...@@ -103,6 +108,7 @@ class TestData:
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
layer = torch.nn.Module() layer = torch.nn.Module()
layer.orig_dtype = torch.bfloat16
layer.w13_weight = w13_quantized.clone() layer.w13_weight = w13_quantized.clone()
layer.w2_weight = w2_quantized.clone() layer.w2_weight = w2_quantized.clone()
layer.w13_input_scale = a1_scale layer.w13_input_scale = a1_scale
...@@ -115,10 +121,10 @@ class TestData: ...@@ -115,10 +121,10 @@ class TestData:
pcp_size=1, pcp_size=1,
dp_size=1, dp_size=1,
ep_size=1, ep_size=1,
tp_rank=1, tp_rank=0,
pcp_rank=1, pcp_rank=0,
dp_rank=1, dp_rank=0,
ep_rank=1, ep_rank=0,
use_ep=False, use_ep=False,
all2all_backend="naive", all2all_backend="naive",
) )
...@@ -126,7 +132,9 @@ class TestData: ...@@ -126,7 +132,9 @@ class TestData:
# flashinfer expects swapped rows for w13 # flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_trtllm: if is_trtllm:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight
)
register_scales_for_trtllm_fp8_per_tensor_moe( register_scales_for_trtllm_fp8_per_tensor_moe(
layer, layer,
layer.w13_weight_scale, layer.w13_weight_scale,
...@@ -199,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -199,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config=quant_config, quant_config=quant_config,
) )
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
layer=td.layer, layer=td.layer,
hidden_states=td.hidden_states, hidden_states=td.hidden_states,
router_logits=score, router_logits=score,
...@@ -277,17 +285,34 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -277,17 +285,34 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer td.layer.quant_method = td.layer
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=quant_config.is_block_quantized
),
FlashInferExperts(
out_dtype=td.layer.orig_dtype,
quant_config=quant_config,
ep_rank=td.layer.moe_parallel_config.ep_rank,
ep_size=td.layer.moe_parallel_config.ep_size,
tp_rank=td.layer.moe_parallel_config.tp_rank,
tp_size=td.layer.moe_parallel_config.tp_size,
use_dp=False,
use_deepseek_fp8_block_scale=False,
),
)
flashinfer_cutlass_output = kernel(
td.hidden_states, td.hidden_states,
td.layer, td.layer.w13_weight,
td.layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False,
activation=activation, activation=activation,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
) )
torch.testing.assert_close( torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
) )
...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config, Fp8Config,
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod, Fp8MoEMethod,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -278,8 +277,18 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -278,8 +277,18 @@ def test_scaled_fp8_quant(dtype) -> None:
# this is the case for marlin as well as per-tensor Fp8MoEMethod # this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True @pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading( def test_fp8_reloading(
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init method_cls,
is_checkpoint_fp8_serialized,
weight_block_size,
use_marlin,
dist_init,
monkeypatch,
): ):
# NOTE(rob): this test fails when using DeepGEMM because the
# shapes are invalid. Previously the test was passing because
# we set fp8_backend to None, which sidestepped the issue.
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
if is_checkpoint_fp8_serialized is False: if is_checkpoint_fp8_serialized is False:
pytest.skip("FP8 weight reloading does not support online quantization") pytest.skip("FP8 weight reloading does not support online quantization")
...@@ -307,6 +316,7 @@ def test_fp8_reloading( ...@@ -307,6 +316,7 @@ def test_fp8_reloading(
params_dtype=torch.bfloat16, params_dtype=torch.bfloat16,
weight_loader=default_weight_loader, weight_loader=default_weight_loader,
) )
method.use_marlin = use_marlin
else: else:
layer = FusedMoE( layer = FusedMoE(
...@@ -325,11 +335,6 @@ def test_fp8_reloading( ...@@ -325,11 +335,6 @@ def test_fp8_reloading(
weight_loader=default_weight_loader, weight_loader=default_weight_loader,
) )
# Fp8LinearMethod uses use_marlin
# Fp8MoEMethod uses fp8_backend
method.use_marlin = use_marlin
method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
# capture weights format during loading # capture weights format during loading
original_metadata = [ original_metadata = [
(name, param.shape, getattr(param, "weight_loader", default_weight_loader)) (name, param.shape, getattr(param, "weight_loader", default_weight_loader))
......
...@@ -73,7 +73,6 @@ if HAS_TRITON: ...@@ -73,7 +73,6 @@ if HAS_TRITON:
CutlassExpertsFp8, CutlassExpertsFp8,
CutlassExpertsW4A8Fp8, CutlassExpertsW4A8Fp8,
cutlass_moe_fp4, cutlass_moe_fp4,
cutlass_moe_fp8,
cutlass_moe_w4a8_fp8, cutlass_moe_w4a8_fp8,
) )
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
...@@ -96,7 +95,6 @@ if HAS_TRITON: ...@@ -96,7 +95,6 @@ if HAS_TRITON:
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
"GroupedTopk", "GroupedTopk",
"cutlass_moe_fp8",
"cutlass_moe_fp4", "cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8", "cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8", "CutlassExpertsFp8",
......
...@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8( ...@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
e: int,
n: int,
k: int,
out_dtype: torch.dtype | None, out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
device: torch.dtype,
): ):
assert quant_config.use_fp8_w8a8 assert quant_config.use_fp8_w8a8
super().__init__(quant_config) super().__init__(quant_config)
# E: num_experts
# N: intermediate size per partition
# K: hidden dim
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1 self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2 self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1 self.c_strides1 = c_strides1
self.c_strides2 = c_strides2 self.c_strides2 = ab_strides1_c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
...@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base): class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config,
)
@property @property
def activation_formats( def activation_formats(
self, self,
...@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
self, self,
max_experts_per_worker: int, max_experts_per_worker: int,
num_dispatchers: int, num_dispatchers: int,
out_dtype: torch.dtype | None, *args,
ab_strides1: torch.Tensor, **kwargs,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
): ):
super().__init__( super().__init__(*args, **kwargs)
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config,
)
assert max_experts_per_worker > 0 assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
...@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert quant_config is not None
if quant_config.a1_scale is not None:
assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1)
if quant_config.a2_scale is not None:
assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1)
if quant_config.w1_scale is not None:
if quant_config.per_out_ch_quant:
assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size(
1
) == w1_q.size(1)
else:
assert (
quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1
)
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
),
)
return fn(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
"""Base class for runtime dispatching of expert implementations."""
def __init__(
self,
experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
):
super().__init__(experts.quant_config)
self.fallback_experts = fallback_experts
self.experts = experts
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
assert (
self.fallback_experts.activation_formats == self.experts.activation_formats
)
return self.fallback_experts.activation_formats
def supports_chunking(self) -> bool:
assert (
self.experts.supports_chunking()
== self.fallback_experts.supports_chunking()
)
return (
self.experts.supports_chunking()
and self.fallback_experts.supports_chunking()
)
def supports_expert_map(self) -> bool:
assert (
self.experts.supports_expert_map()
== self.fallback_experts.supports_expert_map()
)
return (
self.experts.supports_expert_map()
and self.fallback_experts.supports_expert_map()
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
e_war = self.experts.finalize_weight_and_reduce_impl()
fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl()
is_dge_war = e_war is not None
is_fbe_war = fbe_war is not None
if is_dge_war and is_fbe_war:
assert e_war == fbe_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got e_war: {e_war}, and fbe_war: {fbe_war}"
)
if e_war is not None:
return e_war
assert fbe_war is not None
return fbe_war
@abstractmethod
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
raise NotImplementedError
@abstractmethod
def _select_experts_impl(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise NotImplementedError
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
experts = self._select_experts_impl(hidden_states, w1, w2)
experts.apply(
output,
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
global_num_experts,
expert_map,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
)
...@@ -100,7 +100,7 @@ direct_register_custom_op( ...@@ -100,7 +100,7 @@ direct_register_custom_op(
) )
def flashinfer_fused_moe_per_tensor_scale_fp8( def fi_trtllm_fp8_per_tensor_moe(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None, routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( ...@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
) )
def flashinfer_fused_moe_per_tensor_scale_fp8_fake( def fi_trtllm_fp8_per_tensor_moe_fake(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None, routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake( ...@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
# TODO(bnell): Does this really need to be a torch.op? # TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op( direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8", op_name="fi_trtllm_fp8_per_tensor_moe",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8, op_func=fi_trtllm_fp8_per_tensor_moe,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
prepare_fp8_moe_layer_for_fi,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
prepare_fp8_moe_layer_for_deepgemm,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_group_gemm_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
class Fp8MoeBackend(Enum):
NONE = 0
FLASHINFER_TRTLLM = 1
FLASHINFER_CUTLASS = 2
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
VLLM_CUTLASS = 7
def select_fp8_moe_backend(
block_quant: bool,
tp_size: int,
with_lora_support: bool,
is_act_and_mul: bool = True,
allow_vllm_cutlass: bool = False,
) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# TODO(rob): in a future PR, we will query each mk for
# supported features and return the mk directly, just like
# we do for the Attention Backend.
if with_lora_support:
return Fp8MoeBackend.TRITON
def _make_log_backend(backend_name: str):
return f"Using {backend_name} backend for FP8 MoE"
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
if not is_act_and_mul:
raise ValueError(
"FlashInfer TRTLLM FP8 MoE backend only supports "
"act_and_mul gate_up_project fusion. Please set "
"VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
"FlashInfer CUTLASS backend instead."
)
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability_family(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
"FlashInfer TRTLLM backend instead."
)
logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
if (
current_platform.is_cuda() and not current_platform.has_device_capability(89)
) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
logger.info_once(_make_log_backend("Marlin"), scope="local")
return Fp8MoeBackend.MARLIN
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
moe_use_deep_gemm = False
logger.info_once(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
scope="local",
)
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
if not is_deep_gemm_supported():
use_deep_gemm = False
logger.info_once(
"DeepGEMM is disabled because the platform does not support it.",
scope="local",
)
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
)
elif is_deep_gemm_supported():
logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
return Fp8MoeBackend.AITER
if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
return Fp8MoeBackend.VLLM_CUTLASS
# default to Triton
logger.info_once(_make_log_backend("Triton"), scope="local")
return Fp8MoeBackend.TRITON
def convert_to_fp8_moe_kernel_format(
fp8_backend: Fp8MoeBackend,
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block_quant = hasattr(layer, "weight_block_size")
if fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert block_quant
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm(
w13,
w2,
w13_scale,
w2_scale,
tuple(layer.weight_block_size),
)
elif fp8_backend == Fp8MoeBackend.AITER:
w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
elif fp8_backend == Fp8MoeBackend.MARLIN:
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
layer,
w13,
w2,
w13_scale,
w2_scale,
)
elif fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
)
return w13, w2, w13_scale, w2_scale
def make_fp8_moe_quant_config(
fp8_backend: Fp8MoeBackend,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig | None:
"""
Create FusedMoEQuantConfig for the specifed FP8 Backend.
The FusedMoEQuantConfig holds the scales that are used
at runtime by the Modular Kernel abstraction.
Note that certain kernels (e.g. Flashinfer CUTLASS) need
special Quant configs to handle non-standard inputs to
their kernel interfaces.
In a future PR, we will have this function should be
a method of the modular kernel itself.
"""
# TRTLLM does not use Modular Kernel abstraction yet.
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN is mixed precision W8A16 config.
if fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
block_shape=block_shape,
)
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
assert a1_scale is not None and a2_scale is not None
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
a1_gscale=(1.0 / a1_scale),
a2_gscale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
def make_fp8_moe_kernel(
layer: torch.nn.Module,
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Delayed import is required since the oracle is imported
# by CPU backends which cannot import all of these experts.
# TODO: update the experts to make this not happen.
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, then we will initialzie the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
use_inplace = True
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=moe_quant_config.is_block_quantized
),
FlashInferExperts(
out_dtype=layer.orig_dtype,
quant_config=moe_quant_config,
ep_rank=moe_config.ep_rank,
ep_size=moe_config.ep_size,
tp_rank=moe_config.tp_rank,
tp_size=moe_config.tp_size,
use_dp=(moe_config.dp_size > 1),
use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
),
)
use_inplace = False
elif fp8_backend == Fp8MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=moe_quant_config),
)
elif fp8_backend == Fp8MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=moe_quant_config),
)
elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
TritonOrCutlassExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrCutlassExperts(
out_dtype=moe_config.in_dtype,
e=layer.local_num_experts,
n=layer.intermediate_size_per_partition,
k=layer.hidden_size,
device=layer.w13_weight.device,
quant_config=moe_quant_config,
),
)
elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(quant_config=moe_quant_config),
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
assert fp8_backend == Fp8MoeBackend.TRITON
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config=moe_quant_config),
)
return kernel, use_inplace
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.platforms import current_platform
class TritonOrCutlassExperts(FallbackExperts):
"""Cutlass with fallback to Triton for low latency shapes on SM100."""
def __init__(
self,
e: int,
n: int,
k: int,
out_dtype: torch.dtype | None,
quant_config: FusedMoEQuantConfig,
device: torch.dtype,
):
self.is_sm100 = current_platform.has_device_capability(100)
super().__init__(
experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
fallback_experts=TritonExperts(quant_config),
)
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Small batch fallback for sm100.
if self.is_sm100 and M <= 8:
return self.fallback_experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_meta,
)
else:
return self.experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_meta,
)
def _select_experts_impl(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
# Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts
else:
return self.experts
...@@ -10,77 +10,21 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( ...@@ -10,77 +10,21 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, _valid_deep_gemm,
_valid_deep_gemm_shape, _valid_deep_gemm_shape,
) )
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
) )
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonOrDeepGemmExperts(FallbackExperts):
def __init__( """DeepGemm with fallback to Triton for low latency shapes."""
self,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(quant_config)
self.triton_expert = TritonExperts(quant_config)
self.allow_deep_gemm = (
allow_deep_gemm
and self.quant_config.use_fp8_w8a8
and self.block_shape == get_mk_alignment_for_contiguous_layout()
)
self.deep_gemm_expert = ( def __init__(self, quant_config: FusedMoEQuantConfig):
DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None super().__init__(
) experts=DeepGemmExperts(quant_config),
fallback_experts=TritonExperts(quant_config),
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
assert (
self.deep_gemm_expert is None
or self.triton_expert.activation_formats
== self.deep_gemm_expert.activation_formats
) )
return self.triton_expert.activation_formats
def supports_chunking(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return (dge is None or dge.supports_chunking()) and (
te is None or te.supports_chunking()
)
def supports_expert_map(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return (dge is None or dge.supports_expert_map()) and (
te is None or te.supports_expert_map()
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
dge = self.deep_gemm_expert
te = self.triton_expert
dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
te_war = te.finalize_weight_and_reduce_impl() if te else None
is_dge_war = dge_war is not None
is_te_war = te_war is not None
if is_dge_war and is_te_war:
assert dge_war == te_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got dge_war: {dge_war}, and te_war: {te_war}"
)
if dge_war is not None:
return dge_war
assert te_war is not None
return te_war
def workspace_shapes( def workspace_shapes(
self, self,
...@@ -95,11 +39,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -95,11 +39,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and ( if is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K):
is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K) return self.experts.workspace_shapes(
):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
M, M,
N, N,
K, K,
...@@ -109,7 +50,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -109,7 +50,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta, expert_tokens_meta,
) )
else: else:
return self.triton_expert.workspace_shapes( return self.fallback_experts.workspace_shapes(
M, M,
N, N,
K, K,
...@@ -119,45 +60,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -119,45 +60,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta, expert_tokens_meta,
) )
def apply( def _select_experts_impl(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, ) -> mk.FusedMoEPermuteExpertsUnpermute:
topk_ids: torch.Tensor, if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
activation: str, return self.experts
global_num_experts: int, else:
expert_map: torch.Tensor | None, return self.fallback_experts
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
use_deep_gemm = self.allow_deep_gemm and (
is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)
)
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None
experts.apply(
output,
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
global_num_experts,
expert_map,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
)
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin, prepare_fp8_moe_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE, OCP_MX_BLOCK_SIZE,
...@@ -315,8 +315,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -315,8 +315,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin: elif self.use_marlin:
(workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = ( w13_weight, w2_weight, w13_weight_scale, w2_weight_scale = (
prepare_moe_fp8_layer_for_marlin( prepare_fp8_moe_layer_for_marlin(
layer, layer,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -324,7 +324,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -324,7 +324,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_weight_scale, layer.w2_weight_scale,
) )
) )
layer.workspace = workspace
# TODO(rob): once we apply refactor to Quark, switch to using # TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL. # replace_parameter for compatibility with reloading in RL.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
......
...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im ...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize, create_flashinfer_prepare_finalize,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: ...@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
) )
def rotate_flashinfer_fp8_moe_weights( def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
): ):
"""Shuffle weights for for FI TRT-LLM Format"""
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
epilogue_tile_m = 128 epilogue_tile_m = 128
...@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights( ...@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
def register_scales_for_trtllm_fp8_per_tensor_moe( def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module, layer: torch.nn.Module,
w13_weight_scale: torch.Tensor, w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor, w13_input_scale: torch.Tensor,
w2_weight_scale: torch.Tensor, w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor, w2_input_scale: torch.Tensor,
) -> None: ) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_weight_scale, w13_scale=w13_scale,
w13_input_scale=w13_input_scale, w13_input_scale=w13_input_scale,
w2_scale=w2_weight_scale, w2_scale=w2_scale,
w2_input_scale=w2_input_scale, w2_input_scale=w2_input_scale,
) )
layer.w2_input_scale_inv = 1.0 / w2_input_scale layer.w2_input_scale_inv = 1.0 / w2_input_scale
...@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe( ...@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
layer.output2_scales_scalar = g2_alphas layer.output2_scales_scalar = g2_alphas
def apply_flashinfer_per_tensor_scale_fp8( def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8( ...@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert ( assert (
hasattr(layer, "output1_scales_scalar") hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar") and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar") and hasattr(layer, "output2_scales_scalar")
) )
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
"FusedMoE flashinfer kernels are only supported for Llama4" assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
) )
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function
assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4"
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits, routing_logits=router_logits,
routing_bias=routing_bias, routing_bias=routing_bias,
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl( ...@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
) )
def flashinfer_cutlass_moe_fp8(
hidden_states: torch.Tensor,
layer: torch.nn.Module,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_deepseek_fp8_block_scale: bool = False,
moe: FusedMoEConfig | None = None,
) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None
# Construct modular kernel with block-scale support when requested.
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
),
select_cutlass_fp8_gemm_impl(
moe=moe,
quant_config=quant_config,
out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
),
moe_parallel_config=layer.moe_parallel_config,
)
return fused_experts(
hidden_states,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
...@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> ...@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
FlashinferMoeBackend.TENSORRT_LLM, FlashinferMoeBackend.TENSORRT_LLM,
) )
return backend in backends_supporting_global_sf return backend in backends_supporting_global_sf
def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts, hidden_size, intermediate = w2.shape
min_alignment = 16
padded_intermediate = round_up(intermediate, min_alignment)
if padded_intermediate == intermediate:
return w13, w2, intermediate
logger.info_once(
"Padding intermediate size from %d to %d for up/down projection weights.",
intermediate,
padded_intermediate,
scope="local",
)
up_mult = 2 if is_act_and_mul else 1
padded_gate_up_dim = up_mult * padded_intermediate
# Pad w13 and w2 along its intermediate dimension.
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2
return padded_w13, padded_w2, padded_intermediate
def prepare_fp8_moe_layer_for_fi(
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor | None,
is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convert Fp8 MoE weights to flashinfer kernel format
Note that for trtllm we update the model state dict
with the scale format needed for these kernels.
Note that for per-tensor, we update the layer's
intermediate size if the weights needed padding.
"""
assert hasattr(layer.moe_config, "is_act_and_mul")
block_quant = (
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
if not block_quant:
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
w13,
w2,
layer.moe_config.is_act_and_mul,
)
layer.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
w13 = swap_w13_to_w31(w13)
if block_quant:
w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register
# as nn.Parameters since they are not needed for weight-reloading.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
return w13, w2, w13_scale
...@@ -21,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -21,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_BLOCK_FP8_SUPPORTED,
all_close_1d,
per_tensor_dequantize,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BlockQuantScaleParameter, BlockQuantScaleParameter,
...@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block( ...@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block(
return wq, dg_ws return wq, dg_ws
def prepare_fp8_moe_layer_for_deepgemm(
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
block_shape: tuple[int],
):
w13, w13_scale = deepgemm_post_process_fp8_weight_block(
wq=w13,
ws=w13_scale,
quant_block_shape=block_shape,
use_e8m0=is_deep_gemm_e8m0_used(),
)
w2, w2_scale = deepgemm_post_process_fp8_weight_block(
wq=w2,
ws=w2_scale,
quant_block_shape=block_shape,
use_e8m0=is_deep_gemm_e8m0_used(),
)
return w13, w2, w13_scale, w2_scale
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which """Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory""" can benefit from tensors located far enough from one another in memory"""
...@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): ...@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
replace_parameter(layer, scale_attr, dg_weight_scale) replace_parameter(layer, scale_attr, dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool: def process_fp8_weight_tensor_strategy_moe(
assert x.dim() == 3 weight: torch.Tensor,
b, m, n = x.shape weight_scales: torch.Tensor,
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m shard_size: int,
num_experts: int,
is_act_and_mul: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process moe weights for tensor-wise quantization strategy."""
max_scales = weight_scales.max(dim=1).values
# For w1 case (i.e. not w13): just collapse the last dim since
# there is already just one scale per expert in this case.
if not is_act_and_mul:
assert weight_scales.shape[1] == 1
return weight, weight_scales.max()
# For w13 case (common): require single scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
for expert_id in range(num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
weight[expert_id][start : start + shard_size, :],
weight_scales[expert_id][shard_id],
)
weight[expert_id][start : start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_scales[expert_id]
)
start += shard_size
return weight, max_scales
def process_fp8_input_tensor_strategy_moe(
w13_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process moe input scales for tensor-wise quantization strategy."""
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
logger.info_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
return w13_input_scale.max(), w2_input_scale.max()
...@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8: ...@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8:
return _quant_fp8_method return _quant_fp8_method
def get_marlin_input_dtype(prefix): def get_marlin_input_dtype(prefix: str | None = None):
if envs.VLLM_MARLIN_INPUT_DTYPE is None: if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8": elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
......
...@@ -8,6 +8,7 @@ import vllm._custom_ops as ops ...@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, USE_FP32_REDUCE_DEFAULT,
get_marlin_input_dtype,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
...@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin( ...@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin(
replace_parameter(layer, "bias", bias) replace_parameter(layer, "bias", bias)
def prepare_moe_fp8_layer_for_marlin( def prepare_fp8_moe_layer_for_marlin(
layer: torch.nn.Module, layer: torch.nn.Module,
w13_weight: torch.Tensor, w13_weight: torch.Tensor,
w2_weight: torch.Tensor, w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor, w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor,
input_dtype: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[ """
torch.Tensor, # workspace Shuffle weights and scales into marlin format.
torch.Tensor, # w13_weight
torch.Tensor, # w2_weight Note that this function has the side effect of adding a `workspace`
torch.Tensor, # w13_weight_scale attribute to the layer. This `workspace` does not need to be
torch.Tensor, # w2_weight_scale registered as a Parameter as it is not used during weight reloading.
]: """
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will " "FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade " "be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads." "performance for compute-heavy workloads."
) )
input_dtype = get_marlin_input_dtype()
if input_dtype is not None and input_dtype.itemsize == 1: if input_dtype is not None and input_dtype.itemsize == 1:
raise NotImplementedError("Marlin W8A8 is not supported.") raise NotImplementedError("Marlin W8A8 is not supported.")
...@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE # WORKSPACE
device = layer.w13_weight.device device = layer.w13_weight.device
workspace = marlin_make_workspace_new(device, 4) # NOTE(rob): we do not need to register the workspace as a param
# because it is not used as part of the weight reloading process.
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device) perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT # WEIGHT
...@@ -310,13 +315,7 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -310,13 +315,7 @@ def prepare_moe_fp8_layer_for_marlin(
w13_weight_scale = permute_scales(w13_weight_scale, "w13") w13_weight_scale = permute_scales(w13_weight_scale, "w13")
w2_weight_scale = permute_scales(w2_weight_scale, "w2") w2_weight_scale = permute_scales(w2_weight_scale, "w2")
return ( return w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
)
def pack_fp8_to_int32( def pack_fp8_to_int32(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment