Unverified Commit 659907e3 authored by Zhiyu's avatar Zhiyu Committed by GitHub
Browse files

Enable ModelOpt Llama4 fp8 checkpoint deployment in SGLang (#7129)

parent cb9d91ea
...@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module): ...@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
loaded_weight: torch.tensor, loaded_weight: torch.tensor,
tp_rank: int, tp_rank: int,
): ):
"""Load w2 weights for down projection.
Args:
expert_data: The expert data tensor to load into
shard_dim: The dimension to shard along
shard_id: The shard ID (must be "w2")
loaded_weight: The weight tensor to load from
tp_rank: The tensor parallel rank
"""
if not isinstance(expert_data, torch.Tensor) or not isinstance(
loaded_weight, torch.Tensor
):
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
if expert_data.dim() != 2 or loaded_weight.dim() != 2:
raise ValueError(
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
)
if shard_id != "w2":
raise ValueError(f"shard_id must be 'w2', got {shard_id}")
# Index the loaded weight for tp sharding. # Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim # down_proj: "RowParallel" so tp sharding on input_dim
...@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module): ...@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
if not self.use_presharded_weights: if not self.use_presharded_weights:
if self.use_triton_kernels: if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1) loaded_weight = loaded_weight.transpose(-2, -1)
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
raise ValueError(
f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
)
loaded_weight = loaded_weight.narrow( loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size shard_dim, shard_size * tp_rank, shard_size
) )
...@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module): ...@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank, tp_rank=tp_rank,
) )
return return
if "ModelOpt" in self.quant_method.__class__.__name__: if "ModelOpt" in self.quant_method.__class__.__name__:
if "weight_scale_2" in weight_name or "input_scale" in weight_name: # Determine per-tensor weight scale patterns based on variant
is_fp4_variant = (
"ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
)
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
per_tensor_conditions = (
"weight_scale_2" in weight_name
if is_fp4_variant
else "weight_scale" in weight_name
) or "input_scale" in weight_name
if per_tensor_conditions:
self._load_per_tensor_weight_scale( self._load_per_tensor_weight_scale(
shard_id=shard_id, shard_id=shard_id,
param=param, param=param,
......
...@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import ( from sglang.srt.layers.quantization.utils import (
convert_to_channelwise, convert_to_channelwise,
is_layer_skipped, is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale, requantize_with_max_scale,
) )
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
...@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
if self.exclude_modules and any( if self.exclude_modules and any(
module in prefix for module in self.exclude_modules module in prefix
or (
prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model.")
)
for module in self.exclude_modules
): ):
return None return None
...@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
if self.kv_cache_quant_method and isinstance(layer, RadixAttention): if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
# Add MoE support
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): ...@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config) super().__init__(quant_config)
class ModelOptFp8MoEMethod:
"""MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and activation scale.
Args:
quant_config: The ModelOpt quantization config.
"""
def __new__(cls, *args, **kwargs):
"""
Dynamic class composition pattern.
This allows us to effectively "inject" FusedMoEMethodBase as a parent class
at runtime while avoiding circular import issues.
"""
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
# Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
weight_dtype = (
torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized
else params_dtype
)
weight_loader = extra_weight_attrs.get("weight_loader")
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
),
input_dim=2,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight", w13_weight)
w2_weight = ModelWeightParameter(
data=torch.empty(
num_experts, hidden_size, intermediate_size, dtype=weight_dtype
),
input_dim=2,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight", w2_weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
torch.finfo(torch.float32).min,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
w2_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Set weight loader attributes for scales
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP8 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
# Handle scale parameters
if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2: # Shape: (num_experts, 2)
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
# Requantize each expert's weights using the combined scale
# w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size = layer.w13_weight.shape[1] // 2
for expert_id in range(layer.w13_weight.shape[0]):
start = 0
for shard_id in range(2): # w1 and w3
# Dequantize using the original scale for this shard
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][
start : start + intermediate_size, :
],
layer.w13_weight_scale[expert_id][shard_id],
)
# Requantize using the combined max scale
(
layer.w13_weight[expert_id][
start : start + intermediate_size, :
],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += intermediate_size
# Update the scale parameter to be per-expert instead of per-shard
layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
else:
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
layer.w13_input_scale = Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
layer.w2_input_scale = Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=False, # ModelOpt uses per-tensor quantization
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)
class ModelOptFp4Config(QuantizationConfig): class ModelOptFp4Config(QuantizationConfig):
"""Config class for FP4.""" """Config class for FP4."""
......
import json as json_lib
import logging
import os
from collections.abc import Iterable from collections.abc import Iterable
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
...@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader ...@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cpu from sglang.srt.utils import add_prefix, is_cpu
_is_cpu = is_cpu() _is_cpu = is_cpu()
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
class Llama4ForConditionalGeneration(nn.Module): class Llama4ForConditionalGeneration(nn.Module):
...@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.vision_model = Llama4VisionModel(config.vision_config) # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
self.multi_modal_projector = Llama4MultiModalProjector(config) self.has_vision = self._has_vision_weights(config)
if not self.has_vision:
logger.warning(
"No vision weights found in checkpoint. Model will run in text-only mode. "
"Multimodal capabilities (image processing) will be unavailable."
)
if self.has_vision:
self.vision_model = Llama4VisionModel(config.vision_config)
self.multi_modal_projector = Llama4MultiModalProjector(config)
else:
self.vision_model = None
self.multi_modal_projector = None
# Initialize the language model # Initialize the language model
from sglang.srt.models.llama4 import Llama4ForCausalLM from sglang.srt.models.llama4 import Llama4ForCausalLM
self.language_model = Llama4ForCausalLM( self.language_model = Llama4ForCausalLM(
config.text_config, config.text_config if hasattr(config, "text_config") else config,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("language_model", prefix), prefix=add_prefix("language_model", prefix),
) )
self.logits_processor = LogitsProcessor(config.text_config) self.logits_processor = LogitsProcessor(
config.text_config if hasattr(config, "text_config") else config
)
def _has_vision_weights(self, config) -> bool:
"""Check if the model has vision components by examining the checkpoint."""
model_path = getattr(config, "_name_or_path", None)
if not model_path:
return False
# Check if this is a local path first
if os.path.isdir(model_path):
index_file = os.path.join(model_path, "model.safetensors.index.json")
if os.path.exists(index_file):
return self._check_vision_weights_in_index(index_file)
# For HuggingFace models, we need to check the actual checkpoint
# The config might say it's multimodal, but the checkpoint might be text-only
try:
# Try to access the HuggingFace cache directory
from huggingface_hub import try_to_load_from_cache
# Check if index file exists in cache
index_file_path = try_to_load_from_cache(
repo_id=model_path,
filename="model.safetensors.index.json",
cache_dir=None,
)
if index_file_path and os.path.exists(index_file_path):
return self._check_vision_weights_in_index(index_file_path)
except Exception:
# If we can't access the cache, fall back to config-based detection
pass
# Fallback, assume text-only
return False
def _check_vision_weights_in_index(self, index_file: str) -> bool:
"""Check if the model.safetensors.index.json contains vision weights."""
try:
with open(index_file, "r") as f:
index_data = json_lib.load(f)
vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
weight_names = index_data.get("weight_map", {}).keys()
return any(
pattern in weight_name
for weight_name in weight_names
for pattern in vision_patterns
)
except (OSError, json_lib.JSONDecodeError, KeyError):
return False
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens() pattern = MultiModalityDataPaddingPatternMultimodalTokens()
...@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
self, self,
items: List[MultimodalDataItem], items: List[MultimodalDataItem],
) -> torch.Tensor: ) -> torch.Tensor:
# For text-only models, return None or raise an error
if not self.has_vision or self.vision_model is None:
raise ValueError("Vision model not available for text-only checkpoint")
pixel_values = ( pixel_values = (
torch.concat([item.pixel_values for item in items]) torch.concat([item.pixel_values for item in items])
.to(next(self.vision_model.parameters()).device) .to(next(self.vision_model.parameters()).device)
...@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
**kwargs: object, **kwargs: object,
) -> torch.Tensor: ) -> torch.Tensor:
# For text-only models, pass None for image_data_embedding_func
image_embedding_func = self.get_image_feature if self.has_vision else None
hs = general_mm_embed_routine( hs = general_mm_embed_routine(
input_ids=input_ids, input_ids=input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
language_model=self.language_model, language_model=self.language_model,
image_data_embedding_func=self.get_image_feature, image_data_embedding_func=image_embedding_func,
positions=positions, positions=positions,
) )
...@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
return name, loaded_weight return name, loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"), (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
...@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
num_experts = (
self.config.text_config.num_local_experts
if hasattr(self.config, "text_config")
else self.config.num_local_experts
)
num_experts = self.config.text_config.num_local_experts
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module):
) )
for name, loaded_weight in weights: for name, loaded_weight in weights:
if not "vision" in name: if self._should_skip_weight(name):
continue
name = self._transform_weight_name(name)
if "vision" not in name:
name, loaded_weight = self.permute_qk_weight_for_rotary( name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight name, loaded_weight
) )
for param_name, weight_name, shard_id in stacked_params_mapping: if self._handle_scale_remapping(name, params_dict):
if weight_name not in name: continue
continue
if self._handle_stacked_params(
if "vision" in name: name, loaded_weight, stacked_params_mapping, params_dict
continue ):
name = name.replace(weight_name, param_name) continue
param = params_dict[name]
weight_loader = param.weight_loader if self._handle_expert_weights(
weight_loader(param, loaded_weight, shard_id) name, loaded_weight, expert_params_mapping, params_dict, num_experts
break ):
continue
self._handle_default_weight(name, loaded_weight, params_dict)
def _should_skip_weight(self, name: str) -> bool:
"""Check if we should skip loading this weight."""
return "vision" in name and not self.has_vision
def _transform_weight_name(self, name: str) -> str:
"""Transform weight name by adding language_model prefix if needed."""
if (
not name.startswith("language_model.")
and "vision" not in name
and "multi_modal_projector" not in name
):
return f"language_model.{name}"
return name
def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
"""Handle scale parameter remapping. Returns True if handled."""
if "scale" in name and "expert" not in name:
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
return remapped_name is None
return False
def _handle_stacked_params(
self,
name: str,
loaded_weight: torch.Tensor,
stacked_params_mapping: list,
params_dict: dict,
) -> bool:
"""Handle stacked parameter loading. Returns True if handled."""
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name and "vision" not in name:
transformed_name = name.replace(weight_name, param_name)
param = params_dict[transformed_name]
param.weight_loader(param, loaded_weight, shard_id)
return True
return False
def _handle_expert_weights(
self,
name: str,
loaded_weight: torch.Tensor,
expert_params_mapping: list,
params_dict: dict,
num_experts: int,
) -> bool:
"""Handle expert weight loading for MoE (Mixture of Experts) layers.
Args:
name: Parameter name from the checkpoint
loaded_weight: The weight tensor to be loaded
expert_params_mapping: Mapping of parameter names to expert configurations
params_dict: Dictionary of model parameters
num_experts: Total number of experts in the MoE layer
Returns:
bool: True if the parameter was handled (is an expert parameter), False otherwise
"""
if ".experts" not in name:
return False
if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
return self._handle_other_expert_params(
name, loaded_weight, expert_params_mapping, params_dict
)
if "scale" in name:
return self._handle_expert_scale_params(
name, loaded_weight, params_dict, num_experts
)
else:
return self._handle_expert_weight_params(
name, loaded_weight, params_dict, num_experts
)
def _handle_other_expert_params(
self,
name: str,
loaded_weight: torch.Tensor,
expert_params_mapping: list,
params_dict: dict,
) -> bool:
"""Handle expert parameters that are not gate_up_proj or down_proj weights.
Args:
name: Parameter name from the checkpoint
loaded_weight: The weight tensor to be loaded
expert_params_mapping: List of tuples mapping checkpoint names to model parameters
params_dict: Dictionary of model parameters
Returns:
bool: True if parameter was found and handled, False otherwise
"""
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
if weight_name in name:
transformed_name = name.replace(weight_name, param_name)
param = params_dict[transformed_name]
param.weight_loader(
param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
)
return True
return False
def _transform_expert_name(
self, name: str, is_weight: bool = False
) -> Tuple[str, str, List[str]]:
"""Transform expert parameter name and get shard information.
Args:
name: The original parameter name
is_weight: Whether this is a weight parameter (adds _weight suffix)
Returns:
Tuple of (transformed_name, shard_id, shard_id_list)
"""
suffix = "_weight" if is_weight else ""
if ".gate_up_proj" in name:
transformed_name = name.replace(
".experts.gate_up_proj", f".experts.w13{suffix}"
)
shard_id = "w13"
shard_id_list = ["w1", "w3"]
else: # down_proj
transformed_name = name.replace(
".experts.down_proj", f".experts.w2{suffix}"
)
shard_id = "w2"
shard_id_list = ["w2"]
return transformed_name, shard_id, shard_id_list
def _handle_expert_scale_params(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: dict,
num_experts: int,
) -> bool:
"""Handle quantization scale parameters for expert weights.
Args:
name: Parameter name containing scale information
loaded_weight: Scale tensor to be loaded
params_dict: Dictionary of model parameters
num_experts: Total number of experts for broadcast operations
Returns:
bool: True (always handles scale parameters)
"""
import re
# Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
expert_match = re.search(r"experts\.(\d+)\.", name)
# Transform name
transformed_name, _, _ = self._transform_expert_name(name)
if transformed_name not in params_dict:
return True
param = params_dict[transformed_name]
# Handle scale parameters
if expert_match:
# If we have a specific expert ID, only load for that expert
expert_id = int(expert_match.group(1))
# For scale parameters, we can directly set the value
param.data[expert_id] = loaded_weight
else:
# No expert ID found - this is a single scale for all experts
# Load the same scale for all experts
for expert_id in range(num_experts):
param.data[expert_id] = loaded_weight
return True
def _handle_expert_weight_params(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: dict,
num_experts: int,
) -> bool:
"""Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
Args:
name: Parameter name (should contain gate_up_proj or down_proj)
loaded_weight: Weight tensor(s) to be loaded
params_dict: Dictionary of model parameters
num_experts: Total number of experts for tensor distribution
Returns:
bool: True (always handles weight parameters)
"""
# Transform name and get shard info
transformed_name, _, shard_id_list = self._transform_expert_name(
name, is_weight=True
)
if ".gate_up_proj" in name:
loaded_weight_list = loaded_weight.chunk(2, dim=-1)
else: # down_proj
loaded_weight_list = [loaded_weight]
for param_name, weight_chunk, shard_id in zip(
[transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
):
if param_name not in params_dict:
continue
param = params_dict[param_name]
weight_loader = param.weight_loader
# Handle the case where loaded_weight might be a single tensor for all experts
if weight_chunk.dim() == 2:
# Single tensor case - load for all experts
for expert_id in range(num_experts):
weight_loader(
param,
weight_chunk.T,
param_name,
shard_id=shard_id,
expert_id=expert_id,
)
else: else:
if ".experts" in name: # Multiple experts case - load each expert's weights
# NOTE: llama4 fp8 has different weight format for experts for expert_id in range(num_experts):
if ( weight_loader(
"experts.gate_up_proj" not in name param,
and "experts.down_proj" not in name weight_chunk[expert_id].T,
): param_name,
for mapping in expert_params_mapping: shard_id=shard_id,
param_name, weight_name, expert_id, shard_id = mapping expert_id=expert_id,
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
if ".gate_up_proj" in name:
name_list = [
name.replace(
".experts.gate_up_proj", ".experts.w13_weight"
)
] * 2
loaded_weight_list = loaded_weight.chunk(2, dim=-1)
shard_id_list = ["w1", "w3"]
else:
name_list = [
name.replace(".experts.down_proj", ".experts.w2_weight")
]
shard_id_list = ["w2"]
loaded_weight_list = [loaded_weight]
for name, loaded_weight, shard_id in zip(
name_list, loaded_weight_list, shard_id_list
):
param = params_dict[name]
weight_loader = param.weight_loader
for expert_id in range(num_experts):
weight_loader(
param,
loaded_weight[expert_id].T,
name,
shard_id=shard_id,
expert_id=expert_id,
)
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight)
return True
def _handle_default_weight(
self, name: str, loaded_weight: torch.Tensor, params_dict: dict
):
"""Handle default weight loading."""
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
if hasattr(self.language_model, "set_eagle3_layers_to_capture"): if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment