"src/vscode:/vscode.git/clone" did not exist on "e70138bbe42f93d44fbf8c4704fbbde1cd1fdbc9"
Unverified commit 659907e3, authored by Zhiyu, committed by GitHub

Enable ModelOpt Llama4 fp8 checkpoint deployment in SGLang (#7129)

parent cb9d91ea
@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        """Load w2 weights for down projection.

        Args:
            expert_data: The expert data tensor to load into
            shard_dim: The dimension to shard along
            shard_id: The shard ID (must be "w2")
            loaded_weight: The weight tensor to load from
            tp_rank: The tensor parallel rank
        """
        if not isinstance(expert_data, torch.Tensor) or not isinstance(
            loaded_weight, torch.Tensor
        ):
            raise ValueError("expert_data and loaded_weight must be torch.Tensor")

        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
            raise ValueError(
                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
            )

        if shard_id != "w2":
            raise ValueError(f"shard_id must be 'w2', got {shard_id}")

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
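Aside: a minimal standalone sketch of how the new guard clauses behave. The helper name and tensor shapes below are invented for illustration and are not part of this commit.

```python
import torch

def check_w2_inputs(expert_data, loaded_weight, shard_id):
    # Standalone copy of the guards added above, for illustration only.
    if not isinstance(expert_data, torch.Tensor) or not isinstance(
        loaded_weight, torch.Tensor
    ):
        raise ValueError("expert_data and loaded_weight must be torch.Tensor")
    if expert_data.dim() != 2 or loaded_weight.dim() != 2:
        raise ValueError("Expected 2D tensors")
    if shard_id != "w2":
        raise ValueError(f"shard_id must be 'w2', got {shard_id}")

check_w2_inputs(torch.empty(4, 8), torch.empty(4, 16), "w2")    # passes
# check_w2_inputs([1, 2], torch.empty(4, 16), "w2")    -> ValueError (not a tensor)
# check_w2_inputs(torch.empty(4, 8), torch.empty(4, 16), "w13")  -> ValueError (wrong shard)
```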
@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
        if not self.use_presharded_weights:
            if self.use_triton_kernels:
                loaded_weight = loaded_weight.transpose(-2, -1)
            if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
                raise ValueError(
                    f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
                )
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )
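For context on what the new bounds check protects, here is a minimal, self-contained sketch of the same `narrow`-based tensor-parallel slicing. The shapes and rank are toy values, not taken from this commit.

```python
import torch

# Row-parallel sharding: each tp rank keeps a contiguous slice along shard_dim.
loaded_weight = torch.randn(8, 16)          # (output_dim, input_dim), toy sizes
tp_size, tp_rank, shard_dim = 4, 2, 1
shard_size = loaded_weight.shape[shard_dim] // tp_size  # 4 columns per rank

# Same guard as above: the requested slice must fit inside the loaded tensor.
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
    raise ValueError(
        f"Shard size {shard_size} at rank {tp_rank} exceeds "
        f"loaded_weight dimension {loaded_weight.shape[shard_dim]}"
    )

shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
print(shard.shape)  # torch.Size([8, 4])
```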
@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
                tp_rank=tp_rank,
            )
            return

        if "ModelOpt" in self.quant_method.__class__.__name__:
            # Determine per-tensor weight scale patterns based on variant
            is_fp4_variant = (
                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
            )

            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
            per_tensor_conditions = (
                "weight_scale_2" in weight_name
                if is_fp4_variant
                else "weight_scale" in weight_name
            ) or "input_scale" in weight_name

            if per_tensor_conditions:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
...
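The branch above selects the per-tensor loading path purely from the parameter name, and that name differs between the FP4 and FP8 ModelOpt variants. A standalone sketch of the same predicate, with hypothetical checkpoint keys, shows which names it routes to `_load_per_tensor_weight_scale`:

```python
def is_per_tensor_scale(weight_name: str, quant_method_name: str) -> bool:
    # Standalone copy of the condition added above, for illustration only.
    is_fp4_variant = "ModelOptNvFp4FusedMoEMethod" in quant_method_name
    # FP4 stores per-tensor scales as "weight_scale_2"; FP8 as "weight_scale".
    return (
        "weight_scale_2" in weight_name
        if is_fp4_variant
        else "weight_scale" in weight_name
    ) or "input_scale" in weight_name

# Hypothetical checkpoint keys:
print(is_per_tensor_scale("experts.w13_weight_scale", "ModelOptFp8MoEMethod"))         # True
print(is_per_tensor_scale("experts.w13_weight_scale", "ModelOptNvFp4FusedMoEMethod"))  # False
print(is_per_tensor_scale("experts.w13_input_scale", "ModelOptNvFp4FusedMoEMethod"))   # True
```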
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
    convert_to_channelwise,
    is_layer_skipped,
    per_tensor_dequantize,
    requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if self.exclude_modules and any(
            module in prefix
            or (
                prefix.startswith("language_model.")
                and module in prefix.removeprefix("language_model.")
            )
            for module in self.exclude_modules
        ):
            return None
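To make the widened exclusion check concrete, here is a standalone sketch of the predicate as written, with hypothetical module names. It also matches against the prefix with a leading `language_model.` stripped, presumably because multimodal Llama4 checkpoints nest the text model under that prefix.

```python
def is_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Standalone copy of the predicate above, for illustration only.
    return any(
        module in prefix
        or (
            prefix.startswith("language_model.")
            and module in prefix.removeprefix("language_model.")
        )
        for module in exclude_modules
    )

exclude = ["lm_head"]
print(is_excluded("lm_head", exclude))                 # True
print(is_excluded("language_model.lm_head", exclude))  # True
print(is_excluded("model.layers.0.mlp", exclude))      # False
```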
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)

        # Add MoE support
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self)

        return None

    def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
        super().__init__(quant_config)


class ModelOptFp8MoEMethod:
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __new__(cls, *args, **kwargs):
        """
        Dynamic class composition pattern.

        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
        at runtime while avoiding circular import issues.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    torch.finfo(torch.float32).min,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
                ),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                        start += intermediate_size

                # Update the scale parameter to be per-expert instead of per-shard
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

        # Expert selection
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            no_combine=no_combine,
        )
class ModelOptFp4Config(QuantizationConfig):
    """Config class for FP4."""
...
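The most intricate step in `ModelOptFp8MoEMethod.process_weights_after_loading` is folding the two per-shard w13 scales into a single per-expert scale. A toy numeric sketch of that idea, using plain float tensors and simplified stand-ins for `per_tensor_dequantize` and `scaled_fp8_quant` (the real helpers operate on FP8 storage), may help:

```python
import torch

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor dequantization: real value = stored value * scale (toy stand-in).
    return q.float() * scale

def quantize(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor quantization with a fixed scale, clamped to the FP8 e4m3 range (toy stand-in).
    return torch.clamp(x / scale, -448.0, 448.0)

num_experts, inter, hidden = 2, 3, 4
w13 = torch.randn(num_experts, 2 * inter, hidden)        # stored w1/w3 blocks (toy values)
w13_scale = torch.tensor([[0.02, 0.05], [0.03, 0.01]])   # one scale per (expert, shard)

# Keep the max of the w1/w3 scales per expert, then requantize both shards with it.
max_scales = w13_scale.max(dim=1).values                  # shape (num_experts,)
for e in range(num_experts):
    start = 0
    for s in range(2):                                     # shard 0 = w1, shard 1 = w3
        block = w13[e][start : start + inter, :]
        dq = dequantize(block, w13_scale[e][s])            # back to real values
        w13[e][start : start + inter, :] = quantize(dq, max_scales[e])
        start += inter

# A single scale per expert now describes both shards.
print(max_scales)  # tensor([0.0500, 0.0300])
```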