Unverified Commit fac07c9b authored by strgrb's avatar strgrb Committed by GitHub
Browse files

Support LingV2 model (#10359)


Co-authored-by: default avatar羽癫 <yudian.zy@antgroup.com>
Co-authored-by: default avatarguoyuhong <yuhong.gyh@antgroup.com>
parent b3839a7f
......@@ -13,8 +13,8 @@ from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
get_config_dtype_str,
get_config_file_name,
get_default_config,
......@@ -441,6 +441,15 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"BailingMoEForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
......
......@@ -141,6 +141,11 @@ class ModelConfig:
if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
self.hf_config.architectures[0] = "MiMoMTP"
if is_draft_model and self.hf_config.architectures[0] in [
"BailingMoeV2ForCausalLM",
"BailingMoeForCausalLM",
]:
self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN"
if (
is_draft_model
and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
......
......@@ -893,6 +893,35 @@ class QKVParallelLinear(ColumnParallelLinear):
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def _load_qkv_block_scale(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
block_n, _ = self.quant_method.quant_config.weight_block_size
q_size = self.total_num_heads * self.head_size // block_n
k_size = self.total_num_kv_heads * self.head_size // block_n
v_size = self.total_num_kv_heads * self.head_size // block_n
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, q_size),
("k", q_size, k_size),
("v", q_size + k_size, v_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
param.load_qkv_weight(
loaded_weight=loaded_weight_shard,
num_heads=self.num_kv_head_replicas,
shard_id=shard_id,
shard_offset=rank_shard_offset,
shard_size=rank_shard_size,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
def weight_loader_v2(
self,
param: BasevLLMParameter,
......@@ -906,6 +935,9 @@ class QKVParallelLinear(ColumnParallelLinear):
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
elif isinstance(param, BlockQuantScaleParameter):
self._load_qkv_block_scale(param, loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}
# Copyright 2023-2024 SGLang Team
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
from collections.abc import Iterable
from typing import Optional, Tuple
# coding=utf-8
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SGLang BailingMoE model."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
enable_moe_dense_fully_dp,
)
from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
......@@ -22,356 +54,828 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
class BailingAttention(nn.Module):
class BailingMoEMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: Optional[bool] = True,
prefix: str = "",
):
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
self.total_num_kv_heads = config.num_key_value_heads
assert self.total_num_heads % tp_size == 0
assert self.total_num_kv_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.kv_size = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.tp_size = tp_size
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=(config.use_bias or config.use_qkv_bias),
quant_config=quant_config,
prefix=add_prefix("query_key_value", prefix),
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size,
[intermediate_size] * 2,
bias=config.use_bias,
quant_config=quant_config,
prefix=add_prefix("dense", prefix),
prefix=add_prefix("gate_up_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
self.down_proj = RowParallelLinear(
intermediate_size,
config.hidden_size,
bias=config.use_bias,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
prefix=add_prefix("down_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
is_neox_style=True,
rope_scaling=config.rope_scaling,
)
if config.hidden_act != "silu":
raise ValueError("Unsupported activation. Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
forward_batch: ForwardBatch,
forward_batch: Optional[ForwardBatch] = None,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if (self.tp_size == 1) and hidden_states.shape[0] == 0:
return hidden_states
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(q, k, v, forward_batch)
attn_output, _ = self.dense(context_layer)
return attn_output
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
hidden_states, _ = self.down_proj(hidden_states)
return hidden_states
class BailingMLP(nn.Module):
class BailingMoEGate(nn.Module):
def __init__(
self,
intermediate_size: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: Optional[bool] = True,
config,
params_dtype: Optional[torch.dtype] = None,
prefix: str = "",
) -> None:
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size,
[intermediate_size] * 2,
bias=config.use_bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
config.hidden_size,
bias=config.use_bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.weight = nn.Parameter(
torch.empty(
(config.num_experts, config.hidden_size),
dtype=self.params_dtype,
),
)
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
if getattr(config, "moe_router_enable_expert_bias", False):
self.expert_bias = nn.Parameter(
torch.empty((config.num_experts,), dtype=torch.float32),
)
else:
self.expert_bias = None
def forward(self, hidden_states):
logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
hidden_states.dtype
)
return logits
class BailingMoE(nn.Module):
class BailingMoESparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
prefix: str = "",
):
super().__init__()
self.layer_id = layer_id
self.alt_stream = alt_stream
self.tp_size = get_tensor_model_parallel_world_size()
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_size = config.hidden_size
self.num_shared_experts = config.num_shared_experts
self.norm_expert_prob = config.norm_topk_prob
self.moe_intermediate_size = config.moe_intermediate_size
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.score_function = getattr(config, "score_function", None)
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
# Gate always runs at half / full precision for now.
router_dtype = getattr(config, "router_dtype", None)
if router_dtype is None:
self.router_dtype = None
elif router_dtype == "fp32":
self.router_dtype = torch.float32
else:
self.router_dtype = torch.bfloat16
# TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
assert global_server_args_dict["ep_num_redundant_experts"] == 0
# check group topk
self.num_expert_group = getattr(config, "n_group", 0)
self.topk_group = getattr(config, "topk_group", 0)
if self.num_expert_group > 0 or self.topk_group > 0:
assert (
self.num_expert_group > 0
and 0 < self.topk_group <= self.num_expert_group
)
self.use_grouped_topk = True
else:
self.num_expert_group = self.topk_group = None
self.use_grouped_topk = False
self.gate = ReplicatedLinear(
self.hidden_size, self.num_experts, bias=False, quant_config=None
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
self.gate = BailingMoEGate(
config=config,
params_dtype=self.router_dtype,
prefix=add_prefix("gate", prefix),
)
self.correction_bias = (
self.gate.expert_bias.data if self.gate.expert_bias is not None else None
)
self.experts = FusedMoE(
if self.score_function is not None:
assert (
self.score_function == "softmax" and self.correction_bias is None
) or (
self.score_function == "sigmoid" and self.correction_bias is not None
), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"
self.topk = TopK(
top_k=self.top_k,
renormalize=self.norm_topk_prob,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
# num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=self.topk_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
num_experts=self.num_experts,
top_k=self.top_k,
layer_id=layer_id,
hidden_size=self.hidden_size,
intermediate_size=self.moe_intermediate_size,
reduce_results=False,
layer_id=self.layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
)
if self.num_shared_experts > 0:
shared_intermediate_size = (
self.moe_intermediate_size * self.num_shared_experts
)
self.shared_experts = BailingMLP(
intermediate_size=shared_intermediate_size,
# shared expert
if config.num_shared_experts is not None:
if hasattr(config, "moe_shared_expert_intermediate_size"):
intermediate_size = config.moe_shared_expert_intermediate_size
else:
intermediate_size = config.moe_intermediate_size
intermediate_size *= config.num_shared_experts
# disable tp for shared experts when enable deepep moe
self.shared_experts = BailingMoEMLP(
intermediate_size=intermediate_size,
config=config,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
else {}
),
)
# dispatcher
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.deepep_dispatcher = DeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=config.num_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True, # TODO
return_recv_hook=True,
)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, use_reduce_scatter)
else:
self.shared_experts = None
return self.forward_deepep(hidden_states, forward_batch)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, self.hidden_size)
def get_moe_weights(self):
return [
x.data
for name, x in self.experts.named_parameters()
if name not in ["correction_bias"]
]
def _forward_shared_experts(self, hidden_states: torch.Tensor):
shared_output = None
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states_flat)
if self.num_shared_experts > 0:
shared_output = self.shared_experts(hidden_states)
return shared_output
router_logits, _ = self.gate(hidden_states_flat)
topk_output = self.topk(hidden_states_flat, router_logits)
final_hidden_states = self.experts(hidden_states_flat, topk_output)
def _forward_router_experts(self, hidden_states: torch.Tensor):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
return self.experts(hidden_states, topk_output)
if shared_output is not None:
def forward_normal_dual_stream(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
router_output = self._forward_router_experts(hidden_states)
current_stream.wait_stream(self.alt_stream)
return router_output, shared_output
def forward_normal(
self,
hidden_states: torch.Tensor,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
and num_tokens > 0
and num_tokens <= DUAL_STREAM_TOKEN_THRESHOLD
):
final_hidden_states, shared_output = self.forward_normal_dual_stream(
hidden_states
)
else:
shared_output = self._forward_shared_experts(hidden_states)
final_hidden_states = self._forward_router_experts(hidden_states)
if self.num_shared_experts > 0:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.tp_size > 1 and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
return final_hidden_states.view(orig_shape)
def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
shared_output = None
forward_mode = forward_batch.forward_mode
if is_non_idle_and_non_empty(forward_mode, hidden_states):
router_logits = self.gate(hidden_states)
if self.num_shared_experts > 0:
shared_output = self.shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
class BailingMoeBlock(nn.Module):
if self.ep_size > 1:
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
forward_batch=forward_batch,
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
)
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states,
topk_idx,
topk_weights,
forward_batch=forward_batch,
)
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
return final_hidden_states
class BailingMoEAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention = BailingAttention(
config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
self.hidden_size = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.total_kv_heads = config.num_key_value_heads
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
assert self.total_num_heads % attn_tp_size == 0
assert self.total_kv_heads % attn_tp_size == 0
assert self.total_num_heads >= self.total_kv_heads
self.num_heads = self.total_num_heads // attn_tp_size
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
self.q_size = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // attn_tp_size
self.kv_size = max(1, self.num_kv_heads * self.head_dim)
self.scale = self.head_dim**-0.5
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_kv_heads,
bias=(config.use_bias or config.use_qkv_bias),
quant_config=quant_config,
prefix=add_prefix("query_key_value", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
if self.use_qk_norm:
self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=config.use_bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("dense", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.mlp = BailingMoE(
config=config,
if hasattr(config, "partial_rotary_factor"):
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
elif hasattr(config, "rotary_dim"):
self.rotary_dim = config.rotary_dim
else:
self.rotary_dim = self.head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
rope_scaling=config.rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
prefix=add_prefix("attn", prefix),
)
self.alt_stream = alt_stream
def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# overlap qk norm
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.query_layernorm(q_by_head)
with torch.cuda.stream(self.alt_stream):
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.key_layernorm(k_by_head)
current_stream.wait_stream(self.alt_stream)
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.query_layernorm(q_by_head)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.key_layernorm(k_by_head)
q = q_by_head.view(q.shape)
k = k_by_head.view(k.shape)
return q, k
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Pre-normalization and residual connection for the attention block
if residual is None:
residual = hidden_states
normed_hidden_states = self.input_layernorm(hidden_states)
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
return hidden_states
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
context_layer = self.attn(q, k, v, forward_batch)
attn_output, _ = self.dense(context_layer)
return attn_output
class BailingMoEBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_attention_dp_size()
self.attention = BailingMoEAttention(
config,
layer_id,
quant_config,
reduce_results=False,
prefix=add_prefix("attention", prefix),
alt_stream=alt_stream,
)
self.layer_id = layer_id
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.is_layer_sparse = self._is_layer_sparse(
config, layer_id=layer_id, is_nextn=False
)
is_previous_layer_sparse = self._is_layer_sparse(
config, layer_id=layer_id - 1, is_nextn=False
)
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
if self.is_layer_sparse:
self.mlp = BailingMoESparseMoeBlock(
layer_id=layer_id,
config=config,
quant_config=quant_config,
alt_stream=alt_stream,
prefix=add_prefix("mlp", prefix),
)
else:
normed_hidden_states, residual = self.input_layernorm(
hidden_states, residual
if enable_moe_dense_fully_dp():
mlp_tp_rank, mlp_tp_size = 0, 1
else:
mlp_tp_rank, mlp_tp_size = None, None
self.mlp = BailingMoEMLP(
intermediate_size=config.intermediate_size,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
tp_rank=mlp_tp_rank,
tp_size=mlp_tp_size,
)
attn_output = self.attention(
hidden_states=normed_hidden_states,
position_ids=position_ids,
self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
)
def _is_layer_sparse(
self, config: PretrainedConfig, layer_id: int, is_nextn: bool
) -> bool:
return is_nextn or (
config.num_experts is not None and layer_id >= config.first_k_dense_replace
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
# Pre-normalization and residual connection for the MLP block
normed_hidden_states, residual = self.post_attention_layernorm(
attn_output, residual
# For DP with padding, reduce scatter can be used instead of all-reduce.
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
forward_batch
)
mlp_output = self.mlp(normed_hidden_states)
return mlp_output, residual
hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
return hidden_states, residual
class BailingMoeModel(nn.Module):
class BailingMoEModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
prefix: str = "",
):
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
if self.pp_group.is_first_rank:
self.word_embeddings = VocabParallelEmbedding(
self.vocab_size,
self.embed_dim,
quant_config=quant_config,
prefix=add_prefix("word_embeddings", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
else:
self.word_embeddings = PPMissingLayer()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
)
self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
self.layers = make_layers(
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: BailingMoeBlock(
config=config,
lambda idx, prefix: BailingMoEBlock(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
alt_stream=alt_stream,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
)
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
hidden_states = input_embeds
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states,
position_ids,
residual,
forward_batch,
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class BailingMoeForCausalLM(nn.Module):
class BailingMoEForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
prefix: str = "",
):
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.model = BailingMoeModel(config=config, quant_config=quant_config)
self.lm_head = ParallelLMHead(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
quant_config=quant_config,
self.quant_config = quant_config
alt_stream = torch.cuda.Stream() if _is_cuda else None
self.model = BailingMoEModel(
config,
quant_config,
alt_stream=alt_stream,
prefix=add_prefix("model", ""),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
# tie_word_embeddings为true,复用tie_word_embeddings,反之是独立的
if config.tie_word_embeddings:
self.lm_head = self.model.word_embeddings
else:
# TODO something wrong with ParallelLMHead with DP attention enabled
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def get_embed_and_head(self):
"""Used by the eagle_worker."""
return self.model.word_embeddings.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
"""Used by the eagle_worker."""
del self.model.word_embeddings.weight
del self.lm_head.weight
self.model.word_embeddings.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
inputs_embeds: Optional[torch.Tensor] = None,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
# compatible with old design
nextn_layer_id = (
0
if self.config.num_hidden_layers == 1
else self.config.num_hidden_layers
)
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
"final_layernorm",
"eh_proj",
"enorm",
"hnorm",
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -381,39 +885,87 @@ class BailingMoeForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if (
("v_head" in name)
or ("inv_freq" in name)
or (self.config.tie_word_embeddings and "lm_head" in name)
):
continue
if (
hasattr(self.config, "norm_head")
and self.config.norm_head
and "lm_head.weight" in name
):
import torch.nn.functional as F
loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
if "model.word_embeddings.weight" == name:
name = "model.embed_tokens.weight"
if is_nextn:
if not name.startswith(nextn_layer_prefix):
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name and "mlp.experts" not in name:
full_param_name = name.replace(weight_name, param_name)
param = params_dict[full_param_name]
param.weight_loader(param, loaded_weight, shard_id)
break
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for p_name, w_name, e_id, s_id in expert_params_mapping:
if w_name in name and "mlp.experts" in name:
full_param_name = name.replace(w_name, p_name)
param = params_dict[full_param_name]
param.weight_loader(
param,
loaded_weight,
full_param_name,
shard_id=s_id,
expert_id=e_id,
)
break
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
......@@ -421,5 +973,30 @@ class BailingMoeForCausalLM(nn.Module):
)
weight_loader(param, loaded_weight)
if not is_nextn:
self.routed_experts_weights_of_layer = {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if not isinstance(layer, PPMissingLayer)
and isinstance(layer.mlp, BailingMoESparseMoeBlock)
}
@classmethod
def get_model_config_for_expert_location(cls, config):
num_groups = getattr(config, "n_group", 0)
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_experts,
num_groups=None if num_groups == 0 else num_groups,
)
class BailingMoeForCausalLM(BailingMoEForCausalLM):
pass
class BailingMoeV2ForCausalLM(BailingMoEForCausalLM):
pass
EntryClass = BailingMoeForCausalLM
EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM]
# coding=utf-8
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SGLang BailingMoENextN model."""
import logging
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
from sglang.srt.utils import add_prefix
LoraConfig = None
logger = logging.getLogger(__name__)
class BailingMoEModelNextN(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
logger.warning(
"Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
)
quant_config = None
self.vocab_size = config.vocab_size
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("word_embeddings", prefix),
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
self.decoder = BailingMoEBlock(
config,
0,
quant_config=quant_config,
# is_nextn=True,
prefix=add_prefix("decoder", prefix),
)
self.shared_head = nn.Module()
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = input_embeds
if hidden_states.shape[0] > 0:
hidden_states = self.eh_proj(
torch.cat(
(
self.enorm(hidden_states),
self.hnorm(forward_batch.spec_info.hidden_states),
),
dim=-1,
)
)
residual = None
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual
)
if not forward_batch.forward_mode.is_idle():
if residual is not None:
hidden_states, _ = self.final_layernorm(hidden_states, residual)
else:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
if hasattr(self, "determine_num_fused_shared_experts"):
# Asystem has determine_num_fused_shared_experts but theta does not.
self.determine_num_fused_shared_experts("BailingMoeForCausalLMNextN")
self.model = BailingMoEModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
super().load_weights(weights, is_nextn=True)
EntryClass = [BailingMoeForCausalLMNextN]
......@@ -754,7 +754,12 @@ class ServerArgs:
)
model_arch = self.get_hf_config().architectures[0]
if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
if model_arch in [
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"BailingMoeV2ForCausalLM",
"BailingMoeV2ForCausalLM",
]:
# Auto set draft_model_path DeepSeek-V3/R1
if self.speculative_draft_model_path is None:
self.speculative_draft_model_path = self.model_path
......@@ -2724,6 +2729,8 @@ def auto_choose_speculative_params(self: ServerArgs):
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"GptOssForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
]:
# The default value for deepseek and gpt-oss
return (3, 1, 4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment