change / sglang · Commits · 5e194b21

Unverified commit 5e194b21, authored Aug 31, 2025 by Guoyuan Lin, committed by GitHub Aug 30, 2025.

[Model] Support Meituan LongCat-Flash && LongCat-Flash-MTP (#9824)
Parent: fd5ce576

Showing 10 changed files with 1940 additions and 11 deletions (+1940, -11).
python/sglang/srt/configs/__init__.py              +2     -0
python/sglang/srt/configs/longcat_flash.py         +104   -0
python/sglang/srt/configs/model_config.py          +12    -0
python/sglang/srt/hf_transformers_utils.py         +2     -0
python/sglang/srt/layers/moe/ep_moe/kernels.py     +74    -0
python/sglang/srt/layers/moe/topk.py               +23    -10
python/sglang/srt/layers/quantization/utils.py     +13    -0
python/sglang/srt/model_executor/model_runner.py   +4     -1
python/sglang/srt/models/longcat_flash.py          +1015  -0
python/sglang/srt/models/longcat_flash_nextn.py    +691   -0
python/sglang/srt/configs/__init__.py

@@ -5,6 +5,7 @@ from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+from sglang.srt.configs.longcat_flash import LongcatFlashConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -16,6 +17,7 @@ __all__ = [
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
+    "LongcatFlashConfig",
     "MultiModalityConfig",
     "KimiVLConfig",
     "MoonViTConfig",
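
With the re-export in place, the config class can be imported from the package root like the existing ones; a minimal check (illustrative only, assumes sglang is importable):

# Illustrative import check, not part of the commit.
from sglang.srt.configs import LongcatFlashConfig

print(LongcatFlashConfig.model_type)  # "longcat_flash"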
python/sglang/srt/configs/longcat_flash.py (new file, mode 100644)

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class LongcatFlashConfig(PretrainedConfig):
    model_type = "longcat_flash"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        intermediate_size=None,
        ffn_hidden_size=12288,
        expert_ffn_hidden_size=2048,
        num_layers=28,
        num_hidden_layers=None,
        num_attention_heads=64,
        ep_size=1,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=128,
        qk_nope_head_dim=128,
        v_head_dim=128,
        n_routed_experts=512,
        moe_topk=12,
        norm_topk_prob=False,
        max_position_embeddings=131072,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mla_scale_q_lora=True,
        mla_scale_kv_lora=True,
        torch_dtype="bfloat16",
        params_dtype="bfloat16",
        rounter_params_dtype="float32",
        router_bias=False,
        topk_method=None,
        routed_scaling_factor=6.0,
        zero_expert_num=256,
        zero_expert_type="identity",
        nextn_use_scmoe=False,
        num_nextn_predict_layers=1,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            params_dtype=params_dtype,
            rounter_params_dtype=rounter_params_dtype,
            topk_method=topk_method,
            router_bias=router_bias,
            nextn_use_scmoe=nextn_use_scmoe,
            num_nextn_predict_layers=num_nextn_predict_layers,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = (
            num_hidden_layers if num_hidden_layers is not None else num_layers
        )
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else ffn_hidden_size
        )
        self.moe_intermediate_size = expert_ffn_hidden_size
        self.num_attention_heads = num_attention_heads
        self.ep_size = ep_size
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.n_routed_experts = n_routed_experts
        self.moe_topk = moe_topk
        self.norm_topk_prob = norm_topk_prob
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mla_scale_q_lora = mla_scale_q_lora
        self.mla_scale_kv_lora = mla_scale_kv_lora
        self.zero_expert_num = zero_expert_num
        self.zero_expert_type = zero_expert_type
        self.routed_scaling_factor = routed_scaling_factor
        self.hidden_act = "silu"
python/sglang/srt/configs/model_config.py

@@ -132,6 +132,13 @@ class ModelConfig:
         if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
             self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
+            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
         if (
@@ -199,6 +206,8 @@ class ModelConfig:
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
+            or "LongcatFlashForCausalLM" in self.hf_config.architectures
+            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -270,6 +279,9 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_attention_layers = self.num_hidden_layers
+        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
+            self.num_attention_layers = self.num_hidden_layers * 2
         self.num_nextn_predict_layers = getattr(
             self.hf_text_config, "num_nextn_predict_layers", None
         )
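
The doubling of num_attention_layers presumably reflects that each LongCat-Flash decoder layer carries two attention (MLA) blocks, so layer-wise KV-cache bookkeeping has to count twice as many attention layers as hidden layers. A minimal sketch with a hypothetical stand-in config (not the sglang implementation):

# Hypothetical stand-in for the HF config; only mirrors the logic added above.
class _HFConfigStub:
    architectures = ["LongcatFlashForCausalLM"]
    num_hidden_layers = 28                 # LongcatFlashConfig default

hf_config = _HFConfigStub()
num_attention_layers = hf_config.num_hidden_layers
if "LongcatFlashForCausalLM" in hf_config.architectures:
    num_attention_layers = hf_config.num_hidden_layers * 2
assert num_attention_layers == 56          # twice the hidden-layer count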
python/sglang/srt/hf_transformers_utils.py

@@ -40,6 +40,7 @@ from sglang.srt.configs import (
     DeepseekVL2Config,
     ExaoneConfig,
     KimiVLConfig,
+    LongcatFlashConfig,
     MultiModalityConfig,
     Step3VLConfig,
 )
@@ -56,6 +57,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     KimiVLConfig.model_type: KimiVLConfig,
     InternVLChatConfig.model_type: InternVLChatConfig,
     Step3VLConfig.model_type: Step3VLConfig,
+    LongcatFlashConfig.model_type: LongcatFlashConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
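
With the registry entry added, a checkpoint whose config.json declares "model_type": "longcat_flash" resolves to LongcatFlashConfig rather than falling back to the generic transformers auto-config path. A simplified sketch of the lookup (the registry dict here is illustrative, not the real _CONFIG_REGISTRY):

# Simplified lookup mirroring what the registry addition enables.
from sglang.srt.configs import LongcatFlashConfig

registry = {LongcatFlashConfig.model_type: LongcatFlashConfig}
config_cls = registry.get("longcat_flash")
assert config_cls is LongcatFlashConfig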
python/sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
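
For LongCat-Flash routing, indices at or above num_experts denote "zero experts" that pass the token through scaled by its routing weight; zero_experts_compute_triton accumulates that identity contribution and clears those slots in place so the regular expert kernels only see real experts. A hypothetical call, with shapes, dtypes, and the CUDA device as assumptions (the values mirror the config defaults above):

# Hypothetical usage sketch; tensor shapes and the CUDA device are assumptions.
import torch

num_experts = 512                      # routed experts; ids >= 512 are zero experts
hidden = torch.randn(4, 6144, device="cuda")
topk_ids = torch.randint(0, num_experts + 256, (4, 12), device="cuda")
topk_scales = torch.rand(4, 12, device="cuda", dtype=torch.float32)

# identity_out holds sum(weight * hidden) over the selected zero experts;
# zero-expert entries in topk_ids / topk_scales are cleared in place.
identity_out = zero_experts_compute_triton(
    topk_ids, topk_scales, num_experts, "identity", hidden
)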
python/sglang/srt/layers/moe/topk.py

@@ -357,17 +357,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-    assert (
-        hidden_states.shape[0] == gating_output.shape[0]
-    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
@@ -380,6 +391,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -825,6 +837,7 @@ def select_experts(
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
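
In plain PyTorch terms, the new correction_bias branch biases only the expert selection, while the returned weights still come from the unbiased softmax scores. A standalone illustration (tensor shapes are arbitrary, chosen for the example):

# Standalone illustration of the correction_bias routing behavior.
import torch

gating_output = torch.randn(4, 8)       # [tokens, experts]
correction_bias = torch.randn(8)
topk = 2

scores = gating_output.softmax(dim=-1)
scores_for_choice = scores + correction_bias.unsqueeze(0)
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)  # weights come from the unbiased scores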
python/sglang/srt/layers/quantization/utils.py

@@ -77,6 +77,19 @@ def is_layer_skipped(
         )
     else:
         is_skipped = prefix in ignored_layers
+        if "gate_up_proj" in prefix:
+            prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
+            prefix_up = prefix.replace("gate_up_proj", "up_proj")
+            if prefix_gate in ignored_layers and prefix_up in ignored_layers:
+                is_skipped = True
+        elif "experts" in prefix:
+            is_skipped = any(
+                [
+                    prefix in layer_name
+                    for layer_name in ignored_layers
+                    if "experts" in layer_name
+                ]
+            )

     assert is_skipped is not None
     return is_skipped
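
In effect, a fused gate_up_proj module is only treated as skipped when both of its unfused halves appear in ignored_layers, and expert sub-layers are matched by substring. A toy walk-through with made-up layer names:

# Toy walk-through of the new matching rules; layer names are made up.
ignored_layers = {
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.up_proj",
}

prefix = "model.layers.0.mlp.gate_up_proj"
is_skipped = prefix in ignored_layers            # False: fused name not listed
if "gate_up_proj" in prefix:
    prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
    prefix_up = prefix.replace("gate_up_proj", "up_proj")
    if prefix_gate in ignored_layers and prefix_up in ignored_layers:
        is_skipped = True                        # both halves ignored -> skip
assert is_skipped is True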
python/sglang/srt/model_executor/model_runner.py

@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
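
The max() guard keeps end_layer large enough for models whose attention-layer count exceeds their hidden-layer count, which is exactly the LongCat-Flash case handled in model_config.py above. A toy evaluation with assumed values:

# Toy evaluation of the revised bound; the numbers are assumptions matching
# a LongCat-Flash target model (28 hidden layers, 56 attention layers).
num_hidden_layers = 28
num_attention_layers = 56
is_draft_worker, model_has_mtp_layers = False, True
num_nextn_predict_layers = 1

model_num_layers = (
    num_nextn_predict_layers
    if is_draft_worker and model_has_mtp_layers
    else max(num_hidden_layers, num_attention_layers)
)
assert model_num_layers == 56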
python/sglang/srt/models/longcat_flash.py (new file, mode 100644)
    Diff collapsed in this view (1015 additions).

python/sglang/srt/models/longcat_flash_nextn.py (new file, mode 100644)
    Diff collapsed in this view (691 additions).