"docs/source/tutorials/syncbn.rst" did not exist on "25985c31fbad7eaff58dbd6575414fad06cb4b42"
Commit fac07c9b authored by strgrb, committed by GitHub

Support LingV2 model (#10359)


Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
Co-authored-by: guoyuhong <yuhong.gyh@antgroup.com>
parent b3839a7f
@@ -13,8 +13,8 @@ from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from sglang.srt.layers.moe.fused_moe_triton import override_config
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_moe,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_dtype_str,
     get_config_file_name,
     get_default_config,
@@ -441,6 +441,15 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] in [
+        "BailingMoEForCausalLM",
+        "BailingMoeForCausalLM",
+        "BailingMoeV2ForCausalLM",
+    ]:
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
......
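For reference, the new benchmark branch derives all MoE tuning shapes directly from the HuggingFace config, and the factor of 2 comes from the fused gate and up projections of each expert. A minimal sketch with made-up Bailing-style config values (not taken from a real Ling checkpoint) shows how the per-rank intermediate size comes out:

# Illustrative only: hypothetical config values, not a real Ling-V2 checkpoint.
from types import SimpleNamespace

config = SimpleNamespace(
    architectures=["BailingMoeV2ForCausalLM"],
    num_experts=64,              # hypothetical expert count
    num_experts_per_tok=6,       # hypothetical top-k
    moe_intermediate_size=1024,  # hypothetical per-expert FFN width
)
tp_size = 4

E = config.num_experts
topk = config.num_experts_per_tok
# Factor 2 accounts for the fused gate and up projections; the result is then
# split evenly across tensor-parallel ranks.
shard_intermediate_size = 2 * config.moe_intermediate_size // tp_size
print(E, topk, shard_intermediate_size)  # 64 6 512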
@@ -141,6 +141,11 @@ class ModelConfig:
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
+        if is_draft_model and self.hf_config.architectures[0] in [
+            "BailingMoeV2ForCausalLM",
+            "BailingMoeForCausalLM",
+        ]:
+            self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN"
         if (
             is_draft_model
             and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
......
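The effect of the remapping above: when a Bailing checkpoint is loaded as the speculative draft model, its architecture string is rewritten so the model loader resolves the NextN entry class added later in this commit instead of the full model. A small illustrative helper (not SGLang API) that mirrors the logic:

# Illustrative stand-alone helper mirroring the ModelConfig change; not SGLang API.
def draft_architecture(architecture: str, is_draft_model: bool) -> str:
    if is_draft_model and architecture in (
        "BailingMoeV2ForCausalLM",
        "BailingMoeForCausalLM",
    ):
        return "BailingMoeForCausalLMNextN"
    return architecture

print(draft_architecture("BailingMoeV2ForCausalLM", is_draft_model=True))
# BailingMoeForCausalLMNextN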
@@ -893,6 +893,35 @@ class QKVParallelLinear(ColumnParallelLinear):
             )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)

+    def _load_qkv_block_scale(
+        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
+    ):
+        block_n, _ = self.quant_method.quant_config.weight_block_size
+        q_size = self.total_num_heads * self.head_size // block_n
+        k_size = self.total_num_kv_heads * self.head_size // block_n
+        v_size = self.total_num_kv_heads * self.head_size // block_n
+        shard_offsets = [
+            # (shard_id, shard_offset, shard_size)
+            ("q", 0, q_size),
+            ("k", q_size, k_size),
+            ("v", q_size + k_size, v_size),
+        ]
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                param.output_dim, shard_offset, shard_size
+            )
+            rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
+            rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight_shard,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=shard_id,
+                shard_offset=rank_shard_offset,
+                shard_size=rank_shard_size,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+
     def weight_loader_v2(
         self,
         param: BasevLLMParameter,
@@ -906,6 +935,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         elif type(param) in (RowvLLMParameter, BasevLLMParameter):
             param.load_qkv_weight(loaded_weight=loaded_weight)
             return
+        elif isinstance(param, BlockQuantScaleParameter):
+            self._load_qkv_block_scale(param, loaded_weight)
+            return
         # TODO: @dsikka - move to parameter.py
         self._load_fused_module_from_checkpoint(param, loaded_weight)
         return
......
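Context for the new _load_qkv_block_scale helper: with block-wise quantization the checkpoint stores one scale per weight_block_size tile, so the fused q/k/v scale tensor is sliced in units of output blocks rather than output rows, using the same narrow-and-load pattern as the weights themselves. A shape-only sketch with made-up head counts, widths, and block size (not Ling-V2 values):

# Shape-only sketch of the q/k/v block-scale split; all sizes are assumptions.
import torch

total_num_heads, total_num_kv_heads, head_size = 16, 4, 128
block_n = 128  # weight_block_size[0]: one scale per 128 output rows

q_rows = total_num_heads * head_size // block_n      # 16
k_rows = total_num_kv_heads * head_size // block_n   # 4
v_rows = total_num_kv_heads * head_size // block_n   # 4

in_blocks = 2048 // 128  # made-up input width of 2048, also block-quantized
qkv_scale = torch.randn(q_rows + k_rows + v_rows, in_blocks)

# Narrow along the output dimension, exactly like the weight itself,
# just measured in blocks instead of rows.
q_scale = qkv_scale.narrow(0, 0, q_rows)
k_scale = qkv_scale.narrow(0, q_rows, k_rows)
v_scale = qkv_scale.narrow(0, q_rows + k_rows, v_rows)
print(q_scale.shape, k_scale.shape, v_scale.shape)
# torch.Size([16, 16]) torch.Size([4, 16]) torch.Size([4, 16])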
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}
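These tuned entries are keyed by the number of tokens routed through the fused MoE kernel. A rough sketch of a nearest-key lookup over such a table; the actual selection logic lives in sglang.srt.layers.moe.fused_moe_triton and may differ:

# Rough sketch of consulting a tuned-config table keyed by token count.
# The nearest-key rule is an assumption for illustration only.
configs = {
    1:    {"BLOCK_SIZE_M": 16, "num_warps": 4},
    64:   {"BLOCK_SIZE_M": 16, "num_warps": 8},
    1024: {"BLOCK_SIZE_M": 64, "num_warps": 8},
}

def pick_config(num_tokens: int) -> dict:
    # Pick the benchmarked batch size closest to the live token count.
    key = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[key]

print(pick_config(48))    # uses the entry tuned for 64 tokens
print(pick_config(2000))  # uses the entry tuned for 1024 tokens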
One file's diff is collapsed and not shown here.
# coding=utf-8
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SGLang BailingMoENextN model."""
import logging
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
from sglang.srt.utils import add_prefix
LoraConfig = None
logger = logging.getLogger(__name__)
class BailingMoEModelNextN(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
            logger.warning(
                "Overriding BailingMoeForCausalLMNextN quant config for modelopt_fp4 model."
            )
            quant_config = None

        self.vocab_size = config.vocab_size

        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not is_dp_attention_enabled(),
            prefix=add_prefix("word_embeddings", prefix),
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
        self.decoder = BailingMoEBlock(
            config,
            0,
            quant_config=quant_config,
            # is_nextn=True,
            prefix=add_prefix("decoder", prefix),
        )

        self.shared_head = nn.Module()
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.word_embeddings(input_ids)
        else:
            hidden_states = input_embeds

        if hidden_states.shape[0] > 0:
            hidden_states = self.eh_proj(
                torch.cat(
                    (
                        self.enorm(hidden_states),
                        self.hnorm(forward_batch.spec_info.hidden_states),
                    ),
                    dim=-1,
                )
            )

        residual = None
        hidden_states, residual = self.decoder(
            positions, hidden_states, forward_batch, residual
        )

        if not forward_batch.forward_mode.is_idle():
            if residual is not None:
                hidden_states, _ = self.final_layernorm(hidden_states, residual)
            else:
                hidden_states = self.final_layernorm(hidden_states)

        return hidden_states
class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        if hasattr(self, "determine_num_fused_shared_experts"):
            # Not every build provides determine_num_fused_shared_experts, so guard the call.
            self.determine_num_fused_shared_experts("BailingMoeForCausalLMNextN")
        self.model = BailingMoEModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("model.shared_head.head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        super().load_weights(weights, is_nextn=True)


EntryClass = [BailingMoeForCausalLMNextN]
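A shape-level sketch of the draft-input construction in BailingMoEModelNextN.forward above: the draft token embeddings and the target model's hidden states are each RMS-normalized, concatenated along the feature dimension, and projected back to hidden_size before the single BailingMoEBlock decoder layer. The snippet uses random tensors, a made-up hidden size, and torch.nn.RMSNorm (PyTorch >= 2.4) in place of SGLang's fused RMSNorm:

# Shape-level sketch only; sizes are assumptions and this is not the real ForwardBatch path.
import torch
from torch import nn

hidden_size, num_tokens = 512, 3
enorm = nn.RMSNorm(hidden_size)   # stand-in for sglang's RMSNorm
hnorm = nn.RMSNorm(hidden_size)
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

token_embeds = torch.randn(num_tokens, hidden_size)   # embeddings of the draft-step tokens
target_hidden = torch.randn(num_tokens, hidden_size)  # hidden states from forward_batch.spec_info

# Normalize both streams, concatenate, and project back to hidden_size.
draft_input = eh_proj(torch.cat((enorm(token_embeds), hnorm(target_hidden)), dim=-1))
print(draft_input.shape)  # torch.Size([3, 512])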
@@ -754,7 +754,12 @@ class ServerArgs:
             )
             model_arch = self.get_hf_config().architectures[0]
-            if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
+            if model_arch in [
+                "DeepseekV3ForCausalLM",
+                "Glm4MoeForCausalLM",
+                "BailingMoeV2ForCausalLM",
+                "BailingMoeV2ForCausalLM",
+            ]:
                 # Auto set draft_model_path DeepSeek-V3/R1
                 if self.speculative_draft_model_path is None:
                     self.speculative_draft_model_path = self.model_path
@@ -2724,6 +2729,8 @@ def auto_choose_speculative_params(self: ServerArgs):
         "DeepseekV3ForCausalLM",
         "DeepseekV2ForCausalLM",
         "GptOssForCausalLM",
+        "BailingMoeForCausalLM",
+        "BailingMoeV2ForCausalLM",
     ]:
         # The default value for deepseek and gpt-oss
         return (3, 1, 4)
......
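On the defaults added in auto_choose_speculative_params: the returned tuple is assumed, based on how the function is used elsewhere in server_args.py, to mean (speculative_num_steps, speculative_eagle_topk, speculative_num_draft_tokens). The stand-alone sketch below is illustrative, not the real function:

# Illustrative only; the tuple ordering and the fallback value are assumptions.
def auto_speculative_defaults(model_arch: str) -> tuple:
    # Assumed meaning: (num_steps, eagle_topk, num_draft_tokens).
    if model_arch in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "GptOssForCausalLM",
        "BailingMoeForCausalLM",
        "BailingMoeV2ForCausalLM",
    ):
        return (3, 1, 4)
    return (5, 4, 8)  # placeholder fallback, not taken from the real file

steps, topk, draft_tokens = auto_speculative_defaults("BailingMoeV2ForCausalLM")
print(steps, topk, draft_tokens)  # 3 1 4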