Unverified Commit 5e194b21 authored by Guoyuan Lin, committed by GitHub

[Model] Support Meituan LongCat-Flash && LongCat-Flash-MTP (#9824)

parent fd5ce576
@@ -5,6 +5,7 @@ from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
    Step3VisionEncoderConfig,
@@ -16,6 +17,7 @@ __all__ = [
    "ChatGLMConfig",
    "DbrxConfig",
    "DeepseekVL2Config",
"LongcatFlashConfig",
"MultiModalityConfig", "MultiModalityConfig",
"KimiVLConfig", "KimiVLConfig",
"MoonViTConfig", "MoonViTConfig",
......
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class LongcatFlashConfig(PretrainedConfig):
model_type = "longcat_flash"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=131072,
hidden_size=6144,
intermediate_size=None,
ffn_hidden_size=12288,
expert_ffn_hidden_size=2048,
num_layers=28,
num_hidden_layers=None,
num_attention_heads=64,
ep_size=1,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=128,
qk_nope_head_dim=128,
v_head_dim=128,
n_routed_experts=512,
moe_topk=12,
norm_topk_prob=False,
max_position_embeddings=131072,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mla_scale_q_lora=True,
mla_scale_kv_lora=True,
torch_dtype="bfloat16",
params_dtype="bfloat16",
rounter_params_dtype="float32",
router_bias=False,
topk_method=None,
routed_scaling_factor=6.0,
zero_expert_num=256,
zero_expert_type="identity",
nextn_use_scmoe=False,
num_nextn_predict_layers=1,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
torch_dtype=torch_dtype,
params_dtype=params_dtype,
rounter_params_dtype=rounter_params_dtype,
topk_method=topk_method,
router_bias=router_bias,
nextn_use_scmoe=nextn_use_scmoe,
num_nextn_predict_layers=num_nextn_predict_layers,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = (
num_hidden_layers if num_hidden_layers is not None else num_layers
)
self.intermediate_size = (
intermediate_size if intermediate_size is not None else ffn_hidden_size
)
self.moe_intermediate_size = expert_ffn_hidden_size
self.num_attention_heads = num_attention_heads
self.ep_size = ep_size
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.n_routed_experts = n_routed_experts
self.moe_topk = moe_topk
self.norm_topk_prob = norm_topk_prob
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mla_scale_q_lora = mla_scale_q_lora
self.mla_scale_kv_lora = mla_scale_kv_lora
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
self.routed_scaling_factor = routed_scaling_factor
self.hidden_act = "silu"
@@ -132,6 +132,13 @@ class ModelConfig:
        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
if (
is_draft_model
and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
):
self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if (
@@ -199,6 +206,8 @@ class ModelConfig:
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures
or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
@@ -270,6 +279,9 @@ class ModelConfig:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.num_attention_layers = self.num_hidden_layers
if "LongcatFlashForCausalLM" in self.hf_config.architectures:
self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
...
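A note on num_attention_layers above: LongCat-Flash apparently places two MLA attention modules in every hidden layer (consistent with the shortcut-connected MoE design hinted at by nextn_use_scmoe), so per-layer state such as the KV cache must be counted per attention layer, not per hidden layer. Stated as a tiny sketch under that assumption:

# Assumption: two attention sub-layers per LongCat-Flash hidden layer.
num_hidden_layers = 28                        # LongcatFlashConfig default
num_attention_layers = 2 * num_hidden_layers  # 56 attention (KV) layers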
@@ -40,6 +40,7 @@ from sglang.srt.configs import (
    DeepseekVL2Config,
    ExaoneConfig,
    KimiVLConfig,
LongcatFlashConfig,
    MultiModalityConfig,
    Step3VLConfig,
)
@@ -56,6 +57,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    KimiVLConfig.model_type: KimiVLConfig,
    InternVLChatConfig.model_type: InternVLChatConfig,
    Step3VLConfig.model_type: Step3VLConfig,
LongcatFlashConfig.model_type: LongcatFlashConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
...
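With the registry entry in place, a checkpoint whose config.json declares "model_type": "longcat_flash" resolves to LongcatFlashConfig. A minimal sketch of the equivalent transformers registration (the loop above presumably does this for each entry):

# Hedged sketch: standard transformers registration for a custom config.
from transformers import AutoConfig
from sglang.srt.configs import LongcatFlashConfig

AutoConfig.register(LongcatFlashConfig.model_type, LongcatFlashConfig)
cfg = AutoConfig.for_model("longcat_flash")  # resolves to LongcatFlashConfig
assert isinstance(cfg, LongcatFlashConfig)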
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
        gateup_input,
        gateup_input_scale,
    )
@triton.jit
def compute_identity_kernel(
top_k,
hidden_states_ptr,
expert_scales_ptr,
num_tokens,
output_ptr,
hidden_dim,
scales_stride,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
batch_id = pid // (hidden_dim // BLOCK_SIZE)
dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
if batch_id >= num_tokens or dim_offset >= hidden_dim:
return
h = tl.load(
hidden_states_ptr
+ batch_id * hidden_dim
+ dim_offset
+ tl.arange(0, BLOCK_SIZE),
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
)
result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for i in range(top_k):
scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
result += h * scale
tl.store(
output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
result,
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
)
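What the kernel computes is easiest to read in plain PyTorch: each token's output is its hidden vector scaled by the sum of its top-k expert scales. A hedged reference implementation (useful for checking the Triton kernel on small inputs):

import torch

def compute_identity_reference(
    hidden_states: torch.Tensor, expert_scales: torch.Tensor
) -> torch.Tensor:
    # output[t] = hidden_states[t] * sum_i expert_scales[t, i]
    scale_sum = expert_scales.sum(dim=-1, keepdim=True)
    return hidden_states * scale_sum.to(hidden_states.dtype)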
def zero_experts_compute_triton(
expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
):
N = expert_indices.numel()
top_k = expert_indices.size(-1)
grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
if zero_expert_type == "identity":
zero_expert_mask = expert_indices < num_experts
zero_expert_scales = expert_scales.clone()
zero_expert_scales[zero_expert_mask] = 0.0
normal_expert_mask = expert_indices >= num_experts
expert_indices[normal_expert_mask] = 0
expert_scales[normal_expert_mask] = 0.0
output = torch.zeros_like(hidden_states).to(hidden_states.device)
hidden_dim = hidden_states.size(-1)
num_tokens = hidden_states.size(0)
grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
compute_identity_kernel[grid](
top_k,
hidden_states,
zero_expert_scales,
num_tokens,
output,
hidden_dim,
zero_expert_scales.stride(0),
BLOCK_SIZE=256,
)
return output
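A hedged usage sketch (a CUDA device is required for the Triton kernel): routed slots whose id is at or beyond num_experts are "zero experts" that bypass the MoE FFN, so the helper splits the top-k into an identity contribution and a masked set that the normal expert kernels can consume. The tensor values below are illustrative only:

import torch

num_experts = 4                                  # real experts: ids 0..3
hidden = torch.randn(3, 512, device="cuda")      # (num_tokens, hidden_dim)
topk_ids = torch.tensor([[1, 4], [5, 2], [0, 3]], device="cuda")
topk_scales = torch.rand(3, 2, device="cuda")    # float32 routing weights

identity_out = zero_experts_compute_triton(
    topk_ids, topk_scales, num_experts, "identity", hidden
)
# topk_ids/topk_scales are modified in place: zero-expert slots now point at
# expert 0 with weight 0.0, so downstream MoE kernels ignore them, while
# identity_out carries scale * hidden for the zero-expert slots.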
@@ -357,7 +357,17 @@ def fused_topk_torch_native(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: torch.Tensor = None,
):
if correction_bias is not None:
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts
) + correction_bias.unsqueeze(0)
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)
else:
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
@@ -368,6 +378,7 @@ def fused_topk_torch_native(
        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
        topk_weights = F.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
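The bias-corrected branch selects experts on biased scores but keeps the unbiased softmax as the mixing weights, in the style of DeepSeek-V3's aux-loss-free routing. A self-contained sketch with toy shapes:

import torch

logits = torch.randn(4, 8)                 # (num_tokens, n_routed_experts)
correction_bias = torch.randn(8)           # per-expert bias from the router
scores = logits.softmax(dim=-1)
biased = scores + correction_bias.unsqueeze(0)
topk_ids = torch.topk(biased, k=2, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)  # weights ignore the bias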
@@ -380,6 +391,7 @@ def fused_topk_cpu(
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
correction_bias: torch.Tensor = None,
):
    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
        hidden_states=hidden_states,
@@ -825,6 +837,7 @@ def select_experts(
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
correction_bias=correction_bias,
        )
    elif custom_routing_function is None:
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
...
@@ -77,6 +77,19 @@ def is_layer_skipped(
        )
    else:
        is_skipped = prefix in ignored_layers
if "gate_up_proj" in prefix:
prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
prefix_up = prefix.replace("gate_up_proj", "up_proj")
if prefix_gate in ignored_layers and prefix_up in ignored_layers:
is_skipped = True
elif "experts" in prefix:
is_skipped = any(
[
prefix in layer_name
for layer_name in ignored_layers
if "experts" in layer_name
]
)
    assert is_skipped is not None
    return is_skipped
...
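The gate_up_proj branch handles fused layers: quantization configs often list gate_proj and up_proj separately, while sglang fuses them into a single gate_up_proj module, so the fused layer is skipped only when both halves are ignored. A small sketch with hypothetical layer names:

# Hedged sketch; the layer names below are illustrative, not from a checkpoint.
ignored_layers = [
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.up_proj",
]
prefix = "model.layers.0.mlp.gate_up_proj"  # the fused sglang module
prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
prefix_up = prefix.replace("gate_up_proj", "up_proj")
is_skipped = prefix_gate in ignored_layers and prefix_up in ignored_layers
assert is_skipped  # both halves ignored, so the fused layer is skipped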
@@ -307,7 +307,10 @@ class ModelRunner:
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else max(
                self.model_config.num_hidden_layers,
                self.model_config.num_attention_layers,
            )
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
...
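Taking max() here keeps every other model unchanged (num_attention_layers equals num_hidden_layers outside LongCat-Flash) while letting LongCat-Flash report its larger attention-layer count. A sketch of the selection with assumed values:

# Hedged sketch of the layer-count selection, with assumed toy values.
num_hidden_layers, num_attention_layers = 28, 56  # LongCat-Flash: 2x
num_nextn_predict_layers = 1
is_draft_worker, model_has_mtp_layers = False, True

model_num_layers = (
    num_nextn_predict_layers
    if is_draft_worker and model_has_mtp_layers
    else max(num_hidden_layers, num_attention_layers)
)
assert model_num_layers == 56  # end_layer defaults to this value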