Unverified commit 168033d5 authored by Ying Sheng, committed by GitHub

Support mxfp4 for GPT-OSS (#8843)


Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Co-authored-by: zhuofan1123 <zhuofanl@nvidia.com>
Co-authored-by: liz-badada <jinyanc@nvidia.com>
Co-authored-by: xutizhou <xutingz@nvidia.com>
Co-authored-by: linhu-nv <linhu@nvidia.com>
parent cbbb7383
@@ -389,7 +389,7 @@ class FusedMoE(torch.nn.Module):
# Narrow parameter and load.
if is_bias:
# this expert_data is a bias, not weight,
# for w2_weight_bias in TP, it does not need to be sharded
shard_size = expert_data.shape[-1]
else:
# this parameter is a weight matrix
@@ -410,10 +410,6 @@ class FusedMoE(torch.nn.Module):
if not is_bias and not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
raise ValueError(
f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
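As a side note for readers, the narrowing above is plain tensor-parallel slicing with `Tensor.narrow`. A self-contained sketch with made-up sizes (not taken from GPT-OSS):

```python
import torch

# Hypothetical checkpoint tensor: 4 experts, 1024 output rows, 512 input cols.
loaded_weight = torch.randn(4, 1024, 512)
tp_size, tp_rank = 4, 1
shard_dim = 1                                             # shard the output dimension
shard_size = loaded_weight.shape[shard_dim] // tp_size    # 256 rows per rank

# Each rank keeps only its contiguous slice of the full tensor.
shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
assert shard.shape == (4, 256, 512)
```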
@@ -461,9 +457,25 @@ class FusedMoE(torch.nn.Module):
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: Optional[int],
) -> None:
# if expert_id is None, then
# all the experts are loaded at the same time
if (
not expert_id
and self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
):
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return
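The mxfp4 early-return above loads all experts at once by copying the (possibly smaller) checkpoint tensor into a pre-padded parameter. A standalone sketch of that partial-copy pattern, with purely illustrative shapes:

```python
import torch

num_experts = 8
# Hypothetical: parameter allocated with padded trailing dims, checkpoint slightly smaller.
param = torch.zeros(num_experts, 768, 160, dtype=torch.uint8)
loaded_weight = torch.randint(0, 256, (num_experts, 720, 144), dtype=torch.uint8)

dim1, dim2 = loaded_weight.shape[1], loaded_weight.shape[2]
param[:, :dim1, :dim2].copy_(loaded_weight)   # the padded tail stays zero
```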
global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
self._weight_loader_impl(
@@ -502,6 +514,7 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
@@ -705,6 +718,18 @@ class FusedMoE(torch.nn.Module):
) -> None:
tp_rank = self.moe_tp_rank
if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
elif "scale" in weight_name:
param.data.copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO: check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
@@ -854,6 +879,33 @@ class FusedMoE(torch.nn.Module):
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
]
@classmethod
def make_expert_params_mapping_fused_mxfp4(
cls,
ckpt_gate_up_proj_name: str,
ckpt_down_proj_name: str,
ckpt_gate_up_proj_bias_name: str,
ckpt_down_proj_bias_name: str,
ckpt_gate_up_proj_scale_name: str,
ckpt_down_proj_scale_name: str,
):
return [
("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
(
"experts.w13_weight_bias",
f"experts.{ckpt_gate_up_proj_bias_name}",
"w13",
),
("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
(
"experts.w13_weight_scale",
f"experts.{ckpt_gate_up_proj_scale_name}",
"w13",
),
("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
]
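For orientation, a hedged sketch of how a (param_name, checkpoint_fragment, shard_id) mapping like the one above is typically consumed when matching checkpoint keys; the layer name below is invented:

```python
mapping = [
    ("experts.w13_weight", "experts.gate_up_proj_blocks", "w13"),
    ("experts.w13_weight_scale", "experts.gate_up_proj_scales", "w13"),
    ("experts.w2_weight", "experts.down_proj_blocks", "w2"),
]

ckpt_key = "model.layers.3.mlp.experts.gate_up_proj_scales"  # hypothetical checkpoint key
for param_name, ckpt_fragment, shard_id in mapping:
    if ckpt_fragment in ckpt_key:
        target = ckpt_key.replace(ckpt_fragment, param_name)
        print(target, shard_id)
        # -> model.layers.3.mlp.experts.w13_weight_scale w13
        break
```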
@classmethod
def make_expert_input_scale_params_mapping(
cls,
......
@@ -186,8 +186,10 @@ def triton_kernel_fused_experts(
def triton_kernel_moe_with_bias_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_pcg,
b1: torch.Tensor,
w2: torch.Tensor,
w2_pcg,
b2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
@@ -209,13 +211,15 @@ def triton_kernel_moe_with_bias_forward(
return triton_kernel_fused_experts_with_bias(
hidden_states,
w1=w1,
w1_pcg=w1_pcg,
b1=b1,
w2=w2,
w2_pcg=w2_pcg,
b2=b2,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=scatter_idx,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
@@ -235,8 +239,10 @@ def triton_kernel_moe_with_bias_forward(
def triton_kernel_fused_experts_with_bias(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_pcg,
b1: torch.Tensor,
w2: torch.Tensor,
w2_pcg,
b2: torch.Tensor,
routing_data: RoutingData,
gather_indx: GatherIndx,
@@ -267,8 +273,10 @@ def triton_kernel_fused_experts_with_bias(
# type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
for w in (w1, w2):
# TODO assert bf16 or mxfp4
# assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
pass
# Shape check
assert hidden_states.ndim == 2, "hidden_states must be 2D"
@@ -287,13 +295,15 @@ def triton_kernel_fused_experts_with_bias(
if global_num_experts == -1:
global_num_experts = E
# TODO maybe completely remove this branch
if w1.dtype == torch.bfloat16:
device = "cuda"
optg = dict()
w1, w1_flex = quantize(w1, "bf16", device, **optg)
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
w2, w2_flex = quantize(w2, "bf16", device, **optg)
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
......
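One way the dtype check deferred by the TODO above could eventually look; this is only a sketch and assumes packed mxfp4 weights arrive as uint8 tensors, which the final kernel-side check may do differently:

```python
import torch

def check_moe_weight_dtype(w) -> None:
    # bf16 covers the unquantized path; uint8 is assumed here to mean packed mxfp4 (two e2m1 values per byte).
    dtype = getattr(w, "dtype", None)
    if dtype not in (torch.bfloat16, torch.uint8):
        raise AssertionError(f"w must be bfloat16 or packed mxfp4, got {dtype=}")
```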
@@ -47,7 +47,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import get_linear_quant_method
@@ -90,7 +91,16 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
}
if is_mxfp_supported:
if is_cuda():
BASE_QUANTIZATION_METHODS.update(
{
"quark": Mxfp4Config,
"mxfp4": Mxfp4Config,
}
)
elif is_mxfp_supported and is_hip():
BASE_QUANTIZATION_METHODS.update(
{
"quark": MxFp4Config,
......
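A minimal sketch of what the registry update buys at lookup time. The resolver function is invented for illustration; it just performs the dict lookup that sglang's config resolution ultimately relies on:

```python
from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS  # assumed import path

def resolve_quantization_config(quant_method: str):
    # Hypothetical helper: map a checkpoint's quant_method string to its config class.
    try:
        return BASE_QUANTIZATION_METHODS[quant_method]
    except KeyError as err:
        raise ValueError(f"Unknown quantization method: {quant_method!r}") from err

# With the update above, "mxfp4" (and "quark") resolves to Mxfp4Config on CUDA with MXFP support.
config_cls = resolve_quantization_config("mxfp4")
```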
@@ -50,315 +50,50 @@ use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
OCP_MX_BLOCK_SIZE = 32
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
class MxFp4Config(QuantizationConfig):
def __init__(
self,
is_checkpoint_fp4_serialized: bool = False,
quant_config: dict[str, Any] = None,
kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder",
ignored_layers: Optional[List[str]] = None,
):
super().__init__()
self.ignored_layers = ignored_layers
if kv_cache_group is None:
kv_cache_group = []
self.is_checkpoint_fp4_serialized = is_checkpoint_fp4_serialized
self.quant_config = quant_config
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
self.packed_modules_mapping = (
self.quant_config["packed_modules_mapping"]
if is_checkpoint_fp4_serialized
else None
)
self.ignored_layers = ignored_layers or []
@classmethod
def from_config(cls, config):
return cls()
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
return 80
def get_name(self) -> str:
return "fp4"
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
# Check if the layer is skipped for quantization.
if len(self.ignored_layers) > 0 and should_ignore_layer(
prefix,
ignore=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
if self.is_checkpoint_fp4_serialized:
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
return MxFp4LinearMethod(self)
elif use_dynamic_mxfp4_linear:
return MxFp4LinearMethod(self)
else:
return UnquantizedLinearMethod()
if isinstance(layer, RadixAttention):
return MxFp4KVCacheMethod(self)
if isinstance(layer, FusedMoE):
return MxFp4MoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
return None
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp4"
def from_config(cls, config: dict[str, Any]) -> "MxFp4Config":
if not mxfp_supported():
platform = torch.cuda.get_device_properties(0).gcnArchName
raise ValueError(
f"Current platform {platform} not support mxfp4 computation"
)
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp4_serialized = (
True if quant_method else False
) # "quark" in quant_method
kv_cache_group = []
pack_method = None
if is_checkpoint_fp4_serialized:
export_config = config.get("export")
if export_config is None:
raise ValueError(
"The export key should be included in "
"the configurations of Quark quantized model"
)
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration
# of kv_cache is stored in layer_quant_config. First, it is
# judged whether kv_cache_group exists, and then it is judged
# whether layer_quant_config has a quantization configuration
# that matches kv_cache.
if len(kv_cache_group) == 0:
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
if not kv_cache_set.issubset(layer_quant_set):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
"but no kv_cache quantization settings "
"were found in the quantization "
"configuration."
)
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
q_configs = [
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
"kv_cache layer in the config is different."
)
kv_cache_config = q_configs[0].get("output_tensors")
if kv_cache_config is None:
raise ValueError("The kv_cache quantization configuration is empty.")
# Since we have already set kv_cache quantization configurations,
# we will remove the quantization configuration for the
# output_tensors corresponding to the kv_cache layer.
for q_config in q_configs:
q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
ignored_layers = cls.get_from_keys_or(config, ["exclude"], None)
return cls(
is_checkpoint_fp4_serialized=is_checkpoint_fp4_serialized,
quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
pack_method=pack_method,
ignored_layers=ignored_layers,
)
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention  # Avoid circular import
def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
capability_tuple = get_device_capability()
if capability_tuple is not None:
assert 0 <= capability_tuple[1] < 10
capability = capability_tuple[0] * 10 + capability_tuple[1]
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.",
)
return supported
else:
return False
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
def _is_mx_fp4(
self,
weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]],
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug(
"Quark model is not in MX-FP4 format: "
"weight_quant or input_quant not set"
)
return False
# Input and weight dtype needs to be fp4.
if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
return False
# Input and weight qscheme needs to be per group.
if (
weight_quant.get("qscheme") != "per_group"
or input_quant.get("qscheme") != "per_group"
):
logger.debug("Quark model is not in MX-FP4 format: not per_group")
return False
# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
return False
# Weights need to use static quantization.
if weight_quant.get("is_dynamic") is True:
logger.debug("Quark model is not in MX-FP4 format: not weight static")
return False
# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
return False
# Activations and weight scales need to be in e8m0 format.
if (
weight_quant.get("scale_format") != "e8m0"
or input_quant.get("scale_format") != "e8m0"
):
logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
return False
return True
def _find_matched_config(
self, layer_name: str, module: torch.nn.Module
) -> dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping:
shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
shard_configs = [
self._find_matched_config(shard_name, module)
for shard_name in shard_names
]
if not all(
deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
):
return UnquantizedLinearMethod()
raise NotImplementedError("Mxfp4 linear layer is not implemented")
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
):
raise ValueError(
f"Found a different quantization configuration for "
f"{shard_proj_names=} in {layer_name=}. vLLM "
"requires all to use the same scheme."
)
return shard_configs[0]
else:
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
layer_type = cast(str, type(module))
layer_type_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_type_quant_config")
)
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]
global_quant_config = cast(
dict[str, Any], self.quant_config.get("global_quant_config")
)
return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported"
)
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFP4(weight_config, input_config)
raise NotImplementedError(
"No quark compatible scheme was found. "
f"{weight_config=}, "
f"{input_config=}"
)
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
def get_scaled_act_names(self) -> List[str]:
return []
class MxFp4LinearMethod(LinearMethodBase):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import importlib
import logging
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from torch.nn.parameter import Parameter
# from vllm.model_executor.layers.fused_moe import (
# FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
# FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda,
is_flashinfer_available,
is_hip,
next_power_of_2,
round_up,
set_weight_attrs,
)
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
if is_flashinfer_available():
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import (
mxfp8_quantize,
shuffle_matrix_a,
shuffle_matrix_sf_a,
trtllm_fp4_block_scale_moe,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.numerics import InFlexData
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1
)
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps
)
if is_cuda() and torch.cuda.get_device_capability()[0] == 10:
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)
quant_tensor = convert_layout(
wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts
)
scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts)
return quant_tensor, InFlexData(), scale
def _dequant_mxfp4(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
try:
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`."
) from err
return mx.dq_mxfp4(x, scale, float_dtype)
def _dequant_mxfp4_fake(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
return torch.empty(
(*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device
)
def _quant_dequant_mxfp4(
x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
try:
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`."
) from err
return mx.qdq_mxfp4(x, scale_calculation_mode)
def _quant_dequant_mxfp4_fake(
x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
return torch.empty_like(x)
try:
direct_register_custom_op(
op_name="dequant_mxfp4",
op_func=_dequant_mxfp4,
mutates_args=[],
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
except AttributeError as error:
raise error
try:
direct_register_custom_op(
op_name="quant_dequant_mxfp4",
op_func=_quant_dequant_mxfp4,
mutates_args=[],
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
except AttributeError as error:
raise error
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
super().__init__()
self.ignored_layers = ignored_layers
@classmethod
def from_config(cls, config):
return cls()
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_name(cls) -> str:
return "mxfp4"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
else:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
def get_scaled_act_names(self) -> List[str]:
return []
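A small usage sketch for the config class above, assuming the import path registered earlier in this commit; the HF-style dict passed in is illustrative:

```python
import torch

from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config

cfg = Mxfp4Config.from_config({"quant_method": "mxfp4"})  # from_config ignores the dict and returns cls()
assert cfg.get_name() == "mxfp4"
assert cfg.get_min_capability() == 80                      # SM80+ per get_min_capability above
assert torch.bfloat16 in cfg.get_supported_act_dtypes()
```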
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
super().__init__()
self.topk_indices_dtype = None
self.use_triton_kernels = use_triton_kernels
self.with_bias = with_bias
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
if torch.cuda.is_available() and has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward as _tk_forward,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
)
self.triton_kernel_moe_forward = _tk_forward
self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# print(f"hi {self=} create_weights {layer=}")
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
intermediate_size *= 2
mxfp4_block = 32
self.intermediate_size = intermediate_size
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size,
hidden_size // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_weight_bias = torch.nn.Parameter(
torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w13_weight_bias", w13_weight_bias)
set_weight_attrs(w13_weight_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.zeros(
num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w2_weight_bias", w2_weight_bias)
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
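To make the parameter shapes above concrete, a worked example with illustrative sizes (not GPT-OSS's actual config): each uint8 weight byte packs two e2m1 values and each group of 32 values shares one e8m0 scale byte, which is where the `// 2` and `// mxfp4_block` divisors come from.

```python
import torch

num_experts, hidden_size, intermediate_size = 4, 1024, 768
intermediate_size *= 2        # mirrors the doubling in create_weights above
mxfp4_block = 32

w13_weight = torch.zeros(num_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.uint8)
w13_scale = torch.zeros(num_experts, 2 * intermediate_size, hidden_size // mxfp4_block, dtype=torch.uint8)

# 2 fp4 values per byte -> packed dim is half the logical hidden_size;
# 1 scale byte per 32 values -> scale dim is hidden_size / 32.
assert w13_weight.shape == (4, 3072, 512)
assert w13_scale.shape == (4, 3072, 32)
```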
def process_weights_after_loading(self, layer):
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# need to delete the original weights to save memory on single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
# - 1.0 means perfect expert distribution.
# - > 1.0 means some experts have more
# tokens than the perfect distribution.
# - < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
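A worked example of the tile sizing above with made-up numbers (512 tokens, top_k = 4, 128 experts), inlining `next_power_of_2` so the snippet runs on its own:

```python
num_tokens, top_k, num_experts = 512, 4, 128
imbalance_factor = 1.3

num_tokens_per_expert = (num_tokens * top_k) // num_experts            # 16 with perfect balance
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)  # 20 after the imbalance factor

tile_tokens_dim = 1 << (num_tokens_per_expert - 1).bit_length()        # next power of 2 -> 32
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)                      # clamp to the kernel's 8..64 range -> 32
```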
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
# avoid import error when triton_kernel is not installed
# from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
# triton_kernel_moe_forward)
"""
if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
assert not self.moe.use_ep, (
"EP is not supported for flashinfer mxfp4 moe backend yet.")
if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
else:
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
x_quant,
x_scale,
layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_weight_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0
layer.w2_weight_bias, # fp32 per expert per channel
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
self.num_experts,
top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
0, # local_expert_offset
self.num_experts, # local num experts
None,
self._get_tile_tokens_dim(x, top_k),
1, # routing_method_type, renormalize
True, # do finalize
)[0]
return trtllm_gen_output
"""
if self.use_triton_kernels:
if self.with_bias:
# TODO: why don't we put the weights on the layer?
assert layer.w13_weight is None
assert layer.w2_weight is None
return self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w1_pcg=self.w13_precision_config,
w2=self.w2_weight_triton_tensor,
w2_pcg=self.w2_precision_config,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
else:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
else:
raise NotImplementedError()
@@ -272,6 +272,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
w1_pcg=None,
w2_pcg=None,
)
else:
return self.triton_kernel_moe_forward(
......
@@ -25,6 +25,8 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_pp_group,
get_tensor_model_parallel_rank,
@@ -108,11 +110,15 @@ class GptOssSparseMoeBlock(nn.Module):
experts_type = get_moe_impl_class()
extra_kwargs = {}
if experts_type.__name__ == "FusedMoE":
quant_config_name = (
quant_config.get_name() if quant_config is not None else None
)
extra_kwargs = {
"enable_flashinfer_cutlass_moe": global_server_args_dict[
"enable_flashinfer_cutlass_moe"
],
# for moe gate_up_proj and down_proj and their bias loading
"use_weight_loader_fused": quant_config_name != "mxfp4",
}
self.experts = experts_type(
num_experts=config.num_local_experts
@@ -350,7 +356,6 @@ class GptOssDecoderLayer(nn.Module):
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix), prefix=add_prefix("self_attn", prefix),
sliding_window_size=self.sliding_window_size, sliding_window_size=self.sliding_window_size,
layer_type=config.layer_types[layer_id], layer_type=config.layer_types[layer_id],
...@@ -538,7 +543,7 @@ class GptOssForCausalLM(nn.Module): ...@@ -538,7 +543,7 @@ class GptOssForCausalLM(nn.Module):
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, # quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
) )
@@ -652,11 +657,188 @@ class GptOssForCausalLM(nn.Module):
return weight_mapping
# TODO beautify code
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
is_nextn: bool = False,
weight_name_mapping: dict = None,
):
quant_config_name = (
self.quant_config.get_name() if self.quant_config is not None else None
)
if quant_config_name != "mxfp4":
self._load_normal_weights(
weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
)
else:
self._load_weights_mxfp4(
weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
)
def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping):
mxfp4_weights = []
normal_weights = []
for name, weight in weights:
if (
".experts" in name
and self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
):
mxfp4_weights.append((name, weight))
else:
normal_weights.append((name, weight))
mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights)
self._load_normal_weights(
normal_weights,
is_nextn=is_nextn,
weight_name_mapping=weight_name_mapping,
other_loaded_param_names=mxfp4_loaded_params,
)
def _load_mxfp4_experts_weights(self, weights):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = intermediate_size_block // tp_size
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
# Attention heads per rank
heads_per_rank = self.config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
num_experts = self.config.num_local_experts
for name, weight in weights:
weight = weight.cuda()
if "gate_up_proj_blocks" in name:
# Handle MLP gate and up projection weights
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight = weight.view(
num_experts, 2 * intermediate_size, -1
).contiguous()
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "down_proj_blocks" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_blocks", "w2_weight")
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight = weight.view(
num_experts, -1, intermediate_size // 2
).contiguous()
narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weights scale
new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "down_proj_scales" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_scales", "w2_weight_scale")
narrow_weight = weight[
..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
if get_moe_tensor_parallel_rank() != 0:
weight = torch.zeros_like(weight)
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_weight_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param, weight, weight_name=new_name, shard_id=None, expert_id=None
)
loaded_params.add(new_name)
return loaded_params
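The slicing arithmetic above is easier to follow with numbers. A sketch using an illustrative intermediate_size of 2880 and tp_size = 4; the real values depend on the checkpoint:

```python
mxfp4_block = 32
intermediate_size = 2880            # illustrative, not necessarily GPT-OSS's value
tp_size, tp_rank = 4, 1

intermediate_size_block = intermediate_size // mxfp4_block           # 90 blocks of 32 values
per_rank_blocks = intermediate_size_block // tp_size                  # 22 blocks per rank
per_rank_intermediate_size = per_rank_blocks * mxfp4_block            # 704 values per rank

tp_rank_start = tp_rank * per_rank_intermediate_size                  # 704
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)  # 1408

# The packed uint8 weight stores 2 values per byte and the scales 1 byte per 32 values,
# so the same bounds are divided by 2 and by 32 when slicing those tensors:
packed_cols = (tp_rank_start // 2, tp_rank_end // 2)                  # (352, 704)
scale_cols = (tp_rank_start // mxfp4_block, tp_rank_end // mxfp4_block)  # (22, 44)
```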
def _load_normal_weights(
self,
weights,
is_nextn: bool,
weight_name_mapping: dict,
other_loaded_param_names=[],
):
tp_rank = get_tensor_model_parallel_rank()
if is_nextn:
@@ -725,15 +907,33 @@ class GptOssForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"):
expert_params_mapping = (
get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
ckpt_gate_up_proj_name="gate_up_proj_blocks",
ckpt_down_proj_name="down_proj_blocks",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
ckpt_down_proj_scale_name="down_proj_scales",
)
)
else:
expert_params_mapping = (
get_moe_impl_class().make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
)
)
params_dict = dict(self.named_parameters())
params_checker = {k: False for k, v in params_dict.items()}
for other_loaded_param_name in other_loaded_param_names:
params_checker[other_loaded_param_name] = True
for name, loaded_weight in weights:
loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
......
@@ -464,6 +464,16 @@ class ServerArgs:
self.enable_triton_kernel_moe = True
self.disable_hybrid_swa_memory = True
quantization_config = getattr(
self.get_hf_config(), "quantization_config", None
)
if (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
):
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
# Set page size
if self.page_size is None:
self.page_size = 1
......
@@ -2124,6 +2124,10 @@ def next_power_of_2(n: int):
return 1 << (n - 1).bit_length() if n > 0 else 1
def round_up(x: int, y: int) -> int:
return ((x - 1) // y + 1) * y
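A quick sanity check of the helper (restated so the snippet runs on its own):

```python
def round_up(x: int, y: int) -> int:
    return ((x - 1) // y + 1) * y

assert round_up(2880, 256) == 3072   # next multiple of 256 at or above 2880
assert round_up(512, 256) == 512     # already-aligned values are unchanged
```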
setattr(triton, "next_power_of_2", next_power_of_2)
......