Unverified Commit b471092d authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][4/N] Marlin Fp8 Mk (#31036)

parent 93cabc41
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config, Fp8Config,
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod, Fp8MoEMethod,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -324,7 +325,10 @@ def test_fp8_reloading( ...@@ -324,7 +325,10 @@ def test_fp8_reloading(
weight_loader=default_weight_loader, weight_loader=default_weight_loader,
) )
# Fp8LinearMethod uses use_marlin
# Fp8MoEMethod uses fp8_backend
method.use_marlin = use_marlin method.use_marlin = use_marlin
method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
# capture weights format during loading # capture weights format during loading
original_metadata = [ original_metadata = [
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_Scheme, OCP_MX_Scheme,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -39,6 +40,7 @@ if has_triton_kernels(): ...@@ -39,6 +40,7 @@ if has_triton_kernels():
def _get_config_dtype_str( def _get_config_dtype_str(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_fp8_w8a16: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None, ocp_mx_scheme: str | None = None,
...@@ -50,6 +52,8 @@ def _get_config_dtype_str( ...@@ -50,6 +52,8 @@ def _get_config_dtype_str(
""" """
if use_fp8_w8a8: if use_fp8_w8a8:
return "fp8_w8a8" return "fp8_w8a8"
elif use_fp8_w8a16:
return "fp8_w8a16"
elif use_int8_w8a16: elif use_int8_w8a16:
return "int8_w8a16" return "int8_w8a16"
elif use_int4_w4a16: elif use_int4_w4a16:
...@@ -319,6 +323,10 @@ class FusedMoEQuantConfig: ...@@ -319,6 +323,10 @@ class FusedMoEQuantConfig:
def use_int8_w8a16(self) -> bool: def use_int8_w8a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == torch.int8 return self._a1.dtype is None and self._w1.dtype == torch.int8
@property
def use_fp8_w8a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == current_platform.fp8_dtype()
@property @property
def use_int4_w4a16(self) -> bool: def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4" return self._a1.dtype is None and self._w1.dtype == "int4"
...@@ -362,6 +370,7 @@ class FusedMoEQuantConfig: ...@@ -362,6 +370,7 @@ class FusedMoEQuantConfig:
""" """
return _get_config_dtype_str( return _get_config_dtype_str(
use_fp8_w8a8=self.use_fp8_w8a8, use_fp8_w8a8=self.use_fp8_w8a8,
use_fp8_w8a16=self.use_fp8_w8a16,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
ocp_mx_scheme=self.ocp_mx_scheme, ocp_mx_scheme=self.ocp_mx_scheme,
...@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config( ...@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config(
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for 16-bit float activations and int4 weights. Construct a quant config for 16-bit float activations and int4 weights.
Note: Activations are pre-quantized.
""" """
group_shape = GroupShape(*block_shape) if block_shape is not None else None group_shape = GroupShape(*block_shape) if block_shape is not None else None
return FusedMoEQuantConfig( return FusedMoEQuantConfig(
...@@ -691,6 +699,27 @@ def int4_w4a16_moe_quant_config( ...@@ -691,6 +699,27 @@ def int4_w4a16_moe_quant_config(
) )
def fp8_w8a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit float activations and fp8 weights.
"""
group_shape = GroupShape(*block_shape) if block_shape is not None else None
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(),
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc(
current_platform.fp8_dtype(), group_shape, w1_scale, None, None
),
_w2=FusedMoEQuantDesc(
current_platform.fp8_dtype(), group_shape, w2_scale, None, None
),
)
def int8_w8a16_moe_quant_config( def int8_w8a16_moe_quant_config(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
...@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config( ...@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config(
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for 16-bit float activations and int8 weights. Construct a quant config for 16-bit float activations and int8 weights.
Note: Activations are pre-quantized.
""" """
group_shape = GroupShape(*block_shape) if block_shape is not None else None group_shape = GroupShape(*block_shape) if block_shape is not None else None
return FusedMoEQuantConfig( return FusedMoEQuantConfig(
......
...@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size, batched_moe_align_block_size,
moe_align_block_size, moe_align_block_size,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
...@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size, marlin_moe_intermediate_size,
marlin_quant_input, marlin_quant_input,
) )
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
...@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
is_k_full: bool = True, is_k_full: bool = True,
): ):
# TODO (varun) : Enable activation quantization # TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, ( assert (
"Supports only mxfp4_w4a16 or int4_w4a16" quant_config.use_mxfp4_w4a16
) or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16
), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
...@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@property @property
def quant_type_id(self) -> int: def quant_type_id(self) -> int:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4 # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
return ( if self.quant_config.use_int4_w4a16:
scalar_types.uint4b8.id return scalar_types.uint4b8.id
if self.quant_config.use_int4_w4a16 elif self.quant_config.use_mxfp4_w4a16:
else scalar_types.float4_e2m1f.id return scalar_types.float4_e2m1f.id
) elif (
self.quant_config.use_fp8_w8a16
and current_platform.fp8_dtype() == torch.float8_e4m3fn
):
return scalar_types.float8_e4m3fn.id
else:
raise NotImplementedError("Unsupported quantization type.")
def moe_problem_size( def moe_problem_size(
self, self,
...@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase):
ops.moe_sum(input, output) ops.moe_sum(input, output)
def modular_marlin_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config),
shared_experts,
)
class BatchedMarlinExperts(MarlinExpertsBase): class BatchedMarlinExperts(MarlinExpertsBase):
def __init__( def __init__(
self, self,
......
...@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
...@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import ( ...@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
) )
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
...@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
self.marlin_input_dtype = None self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
...@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
layer.w13_weight.data = w13_weight.data layer.w13_weight.data = w13_weight.data
if self.use_marlin: if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin( prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype layer, False, input_dtype=self.marlin_input_dtype
) )
...@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
self.use_inplace = False self.use_inplace = False
elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]: elif self.fp8_backend in [
Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN,
]:
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
...@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
config = self.get_fused_moe_quant_config(layer) config = self.get_fused_moe_quant_config(layer)
assert config is not None assert config is not None
self.moe_quant_config = config self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel( use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
MoEPrepareAndFinalizeNoEP(), allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
TritonOrDeepGemmExperts( moe_kernel = (
MarlinExperts(quant_config=self.moe_quant_config)
if use_marlin
else TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM), allow_deep_gemm=allow_deep_gemm,
), )
)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), moe_kernel
) )
self.use_inplace = True self.use_inplace = True
...@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if ( if (
current_platform.is_xpu() self.rocm_aiter_moe_enabled
or self.rocm_aiter_moe_enabled or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.use_marlin
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
return None return None
...@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( assert (
self.fp8_backend != Fp8MoeBackend.MARLIN
) and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet." "Marlin and ROCm AITER are not supported with all2all yet."
) )
...@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
if self.use_marlin: if self.fp8_backend == Fp8MoeBackend.MARLIN:
return None return fp8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
block_shape=self.weight_block_size,
)
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
w1_scale=( w1_scale=(
...@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
elif self.use_marlin:
# TODO(rob): convert this to MK.
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
result = fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
else: else:
result = self.kernel( result = self.kernel(
x, x,
...@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", shuffled_w2) replace_parameter(layer, "w2_weight", shuffled_w2)
# Rushuffle weights for MARLIN if needed. # Rushuffle weights for MARLIN if needed.
if self.use_marlin: if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin( prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype layer, False, input_dtype=self.marlin_input_dtype
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment