Unverified commit bf0f448f authored by Cheng Wan, committed by GitHub

[2/N] MoE Refactor: Unify weight loader and quant methods (#8397)

parent 36d6f0ba
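The diff below removes the dedicated expert-parallel quantization classes (Fp8EPMoEMethod, UnquantizedEPMoEMethod) and instead has the shared quant methods detect an EPMoE layer and hand execution back to the layer's own run_moe(). The following is a minimal sketch of that dispatch shape; all class names here are illustrative stand-ins, not the real sglang API.

```python
# Minimal sketch (not the real sglang API) of the dispatch pattern this commit
# introduces: the shared quant method's apply() hands expert-parallel layers
# back to the layer's own run_moe(), so dedicated *EPMoEMethod classes go away.
from dataclasses import dataclass

import torch


@dataclass
class TopKOutputSketch:  # stand-in for sglang's TopKOutput
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor


class EPMoESketch:
    """Stand-in for an expert-parallel MoE layer that can run itself."""

    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutputSketch) -> torch.Tensor:
        # The real layer dispatches tokens to its local experts and combines results.
        return hidden_states


class MoEQuantMethodSketch:
    """Stand-in for UnquantizedFusedMoEMethod.apply / Fp8MoEMethod.apply."""

    def apply(self, layer, x: torch.Tensor, topk_output: TopKOutputSketch) -> torch.Tensor:
        if isinstance(layer, EPMoESketch):
            # Unified path: EP layers execute through their own run_moe().
            return layer.run_moe(hidden_states=x, topk_output=topk_output)
        # Tensor-parallel layers would fall through to the fused experts kernel here.
        return x
```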
@@ -77,6 +77,7 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
skip_quant: Optional[bool] = False,
):
super().__init__()
@@ -99,9 +100,6 @@ class FusedMoE(torch.nn.Module):
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
if enable_ep_moe:
assert (
self.enable_flashinfer_cutlass_moe
), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create a expert map for the local experts
assert num_experts % self.ep_size == 0
-self.local_num_experts = num_experts // self.ep_size
+self.num_local_experts = num_experts // self.ep_size
self.expert_map[
self.ep_rank
-* self.local_num_experts : (self.ep_rank + 1)
+* self.num_local_experts : (self.ep_rank + 1)
-* self.local_num_experts
+* self.num_local_experts
-] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
else:
self.ep_size = 1
self.ep_rank = 0
-self.local_num_experts = num_experts
+self.num_local_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
if skip_quant:
return
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
@@ -149,7 +150,7 @@ class FusedMoE(torch.nn.Module):
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
-num_experts=self.local_num_experts,
+num_experts=self.num_local_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
@@ -378,6 +379,23 @@ class FusedMoE(torch.nn.Module):
if expert_id == -1:
return
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
def _weight_loader_impl(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
# TP rank is set to 0 if EP is enabled
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
@@ -398,6 +416,10 @@ class FusedMoE(torch.nn.Module):
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
)
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if getattr(self, "use_flashinfer_trtllm_moe", False):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
@@ -605,37 +627,3 @@ class FusedMoE(torch.nn.Module):
("w3", ckpt_up_proj_name),
]
]
def _load_fp8_scale(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if (
param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}"
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
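The layer.py hunks above split the public weight_loader into a thin wrapper plus _weight_loader_impl, and remap shard ids when flashinfer's trtllm path stores the fused gate/up weight in w31 order. Below is a rough sketch of that control flow only, with simplified names and the actual per-shard copy elided; it is not the full sharding logic.

```python
# Rough sketch (simplified names, no real sharding) of the new weight-loader
# split: a thin public weight_loader that delegates to _weight_loader_impl,
# plus the w1/w3 swap applied when the flashinfer trtllm path expects the
# fused gate/up weight in w31 order.
class WeightLoaderSketch:
    use_flashinfer_trtllm_moe = False  # illustrative flag, set per backend

    def weight_loader(self, param, loaded_weight, weight_name, shard_id, expert_id):
        if expert_id == -1:  # expert is not hosted on this EP rank
            return
        self._weight_loader_impl(param, loaded_weight, weight_name, shard_id, expert_id)

    def _weight_loader_impl(self, param, loaded_weight, weight_name, shard_id, expert_id):
        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.")
        if self.use_flashinfer_trtllm_moe:
            # Flashinfer assumes w31 format for w13_weight; swap gate and up.
            shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
        # The real implementation then narrows param.data along the sharded dim
        # and copies loaded_weight into the (expert_id, shard) slot.
```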
@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, EPMoE):
return Fp8EPMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
-layer.num_experts, dtype=torch.float32, device=w13_weight.device
+layer.num_local_experts,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
-for expert in range(layer.num_experts):
+for expert in range(layer.num_local_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-for expert_id in range(layer.num_experts):
+for expert_id in range(layer.num_local_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-for expert_id in range(layer.num_experts):
+for expert_id in range(layer.num_local_experts):
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-for expert_id in range(layer.num_experts):
+for expert_id in range(layer.num_local_experts):
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if isinstance(layer, EPMoE):
layer.w13_weight_scale = (
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
)
layer.w2_weight_scale = (
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
)
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return None
class Fp8EPMoEMethod(Fp8MoEMethod):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts_per_partition,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
layer.w13_weight_scale = torch.nn.Parameter(
torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False,
)
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(
w13_weight, requires_grad=False
)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
return
def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
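For per-tensor FP8, the kernels above expect a single w13 scale per expert, so the Fp8MoEMethod hunks earlier in this diff dequantize the gate and up shards with their own scales and requantize them against the per-expert maximum (the loops now run over layer.num_local_experts). Below is a self-contained toy sketch of that rescaling step, using plain torch float8 conversions in place of sglang's per_tensor_dequantize / scaled_fp8_quant helpers; it assumes a PyTorch build with float8_e4m3fn support and made-up sizes.

```python
# Sketch of merging two per-shard FP8 scales into one per-expert scale:
# dequantize each shard with its own scale, then requantize with the max.
# Plain torch ops stand in for sglang's per_tensor_dequantize /
# scaled_fp8_quant helpers; shapes are toy-sized.
import torch

E, I, H = 2, 4, 8                          # experts, per-shard intermediate, hidden
fp8 = torch.float8_e4m3fn                  # assumes a PyTorch build with float8 support
w13 = torch.randn(E, 2 * I, H).to(fp8)     # toy fused gate/up weight, already FP8
w13_scale = torch.rand(E, 2) + 0.5         # one scale per (expert, shard)

max_scales = w13_scale.max(dim=1).values   # single merged scale per expert
requantized = torch.empty(E, 2 * I, H, dtype=fp8)
for expert_id in range(E):
    for shard_id in range(2):              # shard 0 = w1 (gate), 1 = w3 (up)
        rows = slice(shard_id * I, (shard_id + 1) * I)
        dq = w13[expert_id, rows].to(torch.float32) * w13_scale[expert_id, shard_id]
        requantized[expert_id, rows] = (dq / max_scales[expert_id]).to(fp8)
w13, w13_scale = requantized, max_scales   # kernels now see one scale per expert
```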
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
...
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -194,6 +195,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
if isinstance(layer, EPMoE):
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)
return self.forward(
x=x,
layer=layer,
@@ -354,69 +364,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError("The TPU backend currently does not support MoE.")
forward_native = forward_cpu
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(
self,
layer: torch.nn.Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
w2_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
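Both the removed UnquantizedEPMoEMethod above and the FP8 paths allocate the same fused expert layouts: w13 stacks the gate and up projections as [num_experts, 2 * intermediate_size, hidden_size], and w2 (down_proj) is [num_experts, hidden_size, intermediate_size]. A small illustrative allocation with toy sizes, plain torch tensors rather than the registered nn.Parameters:

```python
# Illustrative allocation of the fused expert weights used by these methods:
# w13 holds gate_up_proj (column parallel), w2 holds down_proj (row parallel).
# Sizes are toy values, not a real model config.
import torch

num_local_experts, hidden_size, intermediate_size = 4, 16, 32
params_dtype = torch.bfloat16

w13_weight = torch.empty(
    num_local_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
)
w2_weight = torch.empty(
    num_local_experts, hidden_size, intermediate_size, dtype=params_dtype
)

# Per-token compute for expert e: x @ w13[e].T -> [2 * intermediate], split into
# gate/up halves, activate, then @ w2[e].T -> [hidden].
assert w13_weight.shape == (4, 64, 16) and w2_weight.shape == (4, 16, 32)
```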
from __future__ import annotations
import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn import Module
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
-elif isinstance(layer, FusedMoE):
+elif isinstance(layer, EPMoE):
return W4AFp8MoEMethod(self)
return None
@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def create_weights(
self,
-layer: Module,
+layer: EPMoE,
-num_experts_per_partition: int,
+num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
-num_experts_per_partition,
+num_experts,
intermediate_size * 2,
hidden_size // 2,
dtype=torch.int8,
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
-num_experts_per_partition,
+num_experts,
hidden_size,
intermediate_size // 2,
dtype=torch.int8,
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
-num_experts_per_partition,
+num_experts,
2 * intermediate_size,
hidden_size // self.quant_config.group_size,
dtype=torch.float32,
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
-num_experts_per_partition,
+num_experts,
hidden_size,
intermediate_size // self.quant_config.group_size,
dtype=torch.float32,
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# Input scales
w13_input_scale = torch.nn.Parameter(
-torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+torch.ones((num_experts, 2), dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
-torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+torch.ones(num_experts, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
device = layer.w13_weight.device
self.a_strides1 = torch.full(
-(num_experts_per_partition, 3),
+(num_experts, 3),
hidden_size,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
-(num_experts_per_partition, 3),
+(num_experts, 3),
2 * intermediate_size,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
-(num_experts_per_partition, 3),
+(num_experts, 3),
intermediate_size,
device=device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
-(num_experts_per_partition, 3),
+(num_experts, 3),
hidden_size,
device=device,
dtype=torch.int64,
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
self.s_strides2 = self.c_strides2
self.expert_offsets = torch.empty(
-(num_experts_per_partition + 1), dtype=torch.int32, device=device
+(num_experts + 1), dtype=torch.int32, device=device
)
self.problem_sizes1 = torch.empty(
-(num_experts_per_partition, 3), dtype=torch.int32, device=device
+(num_experts, 3), dtype=torch.int32, device=device
)
self.problem_sizes2 = torch.empty(
-(num_experts_per_partition, 3), dtype=torch.int32, device=device
+(num_experts, 3), dtype=torch.int32, device=device
)
return
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
[w2_input_scale_max], dtype=dtype, device=device
)
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
def apply(
self,
layer: EPMoE,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
# TODO(ch-wan): move it out of this class
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
topk_ids, topk_weights, _ = topk_output
local_topk_ids = topk_ids
if layer.expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
layer.expert_map[topk_ids] != layer.num_experts,
layer.expert_map[topk_ids],
layer.num_experts,
)
return cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.a_strides1,
self.b_strides1,
self.c_strides1,
self.a_strides2,
self.b_strides2,
self.c_strides2,
self.s_strides13,
self.s_strides2,
self.expert_offsets,
self.problem_sizes1,
self.problem_sizes2,
layer.w13_input_scale,
layer.w2_input_scale,
)
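Under expert parallelism each rank builds an expert_map that sends its own contiguous expert range to local indices 0..num_local_experts-1 and everything else to -1, and W4AFp8MoEMethod.apply above uses that map to translate the router's global top-k ids into local ones before calling cutlass_w4a8_moe. The following is a toy reproduction of that mapping, assuming the same conventions as the FusedMoE constructor hunk at the top of this diff; the sizes are made up.

```python
# Toy reproduction of the EP expert map built in FusedMoE.__init__ and the
# local top-k id translation used by W4AFp8MoEMethod.apply. Sizes are small
# made-up values; the torch.where guard mirrors the convention in the diff.
import torch

num_experts, ep_size, ep_rank = 8, 2, 1
num_local_experts = num_experts // ep_size

# Global expert id -> local id on this rank, or -1 if hosted elsewhere.
expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] = (
    torch.arange(0, num_local_experts, dtype=torch.int32)
)

topk_ids = torch.tensor([[0, 5], [6, 2]])         # global ids chosen by the router
local_topk_ids = torch.where(
    expert_map[topk_ids] != num_experts,          # same guard shape as in apply()
    expert_map[topk_ids],
    num_experts,
)
print(expert_map.tolist())       # [-1, -1, -1, -1, 0, 1, 2, 3] for rank 1
print(local_topk_ids.tolist())   # non-local experts stay mapped to -1
```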