Unverified commit 0ea330ca, authored by fzyzcjy, committed by GitHub
Browse files

Fix wrong weight reference in dynamic EPLB (#6818)

parent 27e327b4
...@@ -91,6 +91,7 @@ from sglang.srt.two_batch_overlap import ( ...@@ -91,6 +91,7 @@ from sglang.srt.two_batch_overlap import (
from sglang.srt.utils import ( from sglang.srt.utils import (
BumpAllocator, BumpAllocator,
DeepEPMode, DeepEPMode,
LazyValue,
add_prefix, add_prefix,
bind_or_assign, bind_or_assign,
get_bool_env_var, get_bool_env_var,
...@@ -1661,6 +1662,18 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -1661,6 +1662,18 @@ class DeepseekV2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size() self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue(
lambda: {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if isinstance(layer.mlp, DeepseekV2MoE)
}
)
@property
def routed_experts_weights_of_layer(self):
    """Return the routed-expert weights collected per MoE layer.

    The lookup is delegated to the ``_routed_experts_weights_of_layer``
    holder created in ``__init__``; reading its ``.value`` builds the
    ``{layer_id: layer.mlp.get_moe_weights()}`` mapping on first access, so
    the weights are gathered when they are first needed rather than at
    construction time.
    """
    return self._routed_experts_weights_of_layer.value
def determine_n_share_experts_fusion( def determine_n_share_experts_fusion(
self, architecture: str = "DeepseekV3ForCausalLM" self, architecture: str = "DeepseekV3ForCausalLM"
): ):
...@@ -1873,14 +1886,6 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -1873,14 +1886,6 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous()) self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
self_attn.use_deep_gemm_bmm = True self_attn.use_deep_gemm_bmm = True
# TODO support nextn later
if not is_nextn:
self.routed_experts_weights_of_layer = {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if isinstance(layer.mlp, DeepseekV2MoE)
}
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn: if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"): if hasattr(self.config, "num_nextn_predict_layers"):
......
...@@ -18,15 +18,10 @@ ...@@ -18,15 +18,10 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import logging import logging
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import Any, Dict, Iterable, Optional, Tuple from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_pp_group, get_pp_group,
...@@ -811,6 +806,7 @@ class Qwen3MoeForCausalLM(nn.Module): ...@@ -811,6 +806,7 @@ class Qwen3MoeForCausalLM(nn.Module):
else: else:
logger.warning(f"Parameter {name} not found in params_dict") logger.warning(f"Parameter {name} not found in params_dict")
# TODO mimic deepseek
self.routed_experts_weights_of_layer = { self.routed_experts_weights_of_layer = {
layer_id: self.model.layers[layer_id].mlp.get_moe_weights() layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
for layer_id in range(self.start_layer, self.end_layer) for layer_id in range(self.start_layer, self.end_layer)
......
...@@ -2257,3 +2257,16 @@ except: ...@@ -2257,3 +2257,16 @@ except:
def cpu_has_amx_support():
    """Report whether AMX can be used on this CPU.

    The result is truthy only when torch's AMX tile probe succeeds and the
    module-level ``is_intel_amx_backend_available`` flag is truthy as well.
    """
    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
class LazyValue:
    """Hold a value that is computed on first access and reused afterwards.

    The ``creator`` callable is not invoked at construction time. It runs on
    the first read of :attr:`value`; its result is stored and the reference
    to the callable is then dropped, so later reads return the stored object
    without calling it again.
    """

    def __init__(self, creator: Callable):
        """Store ``creator`` for deferred evaluation; nothing is computed yet."""
        self._creator = creator
        self._value = None

    @property
    def value(self):
        """Return the created value, invoking the creator on first access only.

        If the creator raises, it is left in place, so the next access will
        try again.
        """
        # A non-None creator means the value has not been produced yet.
        if self._creator is not None:
            self._value = self._creator()
            # Drop the callable so it never runs twice and can be collected.
            self._creator = None
        return self._value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment