sglang / Commits / f64b8e3e

Unverified commit f64b8e3e, authored Sep 02, 2025 by yilian49, committed by GitHub on Sep 02, 2025.
Support the internvl3.5 family models in sglang (#9705)
Parent: 53976fce
Showing 2 changed files with 34 additions and 0 deletions.
python/sglang/srt/configs/internvl.py  +6 -0
python/sglang/srt/models/internvl.py   +28 -0
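Both changes hinge on the language-model architecture string that an InternVL checkpoint declares in the llm_config block of its config.json. As a hedged orientation sketch (not part of this commit), one way to inspect that string is shown below; the repository id is an illustrative assumption, and the llm_config attribute access follows the InternVL config layout that this commit dispatches on.

# Hedged sketch: inspect which language-model backbone an InternVL3.5
# checkpoint declares. The repo id is an assumption used for illustration only.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("OpenGVLab/InternVL3_5-8B", trust_remote_code=True)
print(cfg.llm_config.architectures)  # expected to be something like ["Qwen3ForCausalLM"]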
python/sglang/srt/configs/internvl.py  (view file @ f64b8e3e)
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
from transformers import (
    TOKENIZER_MAPPING,
    GptOssConfig,
    LlamaConfig,
    PretrainedConfig,
    PreTrainedTokenizer,
    Qwen2Config,
    Qwen3Config,
    Qwen3MoeConfig,
)
from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
        elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
            self.llm_config = Qwen2Config(**llm_config)
        elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
            self.llm_config = Qwen3MoeConfig(**llm_config)
        elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
            self.llm_config = Qwen3Config(**llm_config)
        elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
            self.llm_config = GptOssConfig(**llm_config)
        else:
            raise ValueError("Unsupported architecture: {}".format(
...
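Taken together, the branches above implement a simple name-to-class dispatch over the llm_config dict that InternVLChatConfig receives. Below is a minimal, self-contained sketch of that pattern, not the code added by this commit: the sample llm_config dict and the _CONFIG_BY_ARCH table are illustrative, and it assumes a transformers release that ships the Qwen3 config classes.

from transformers import Qwen2Config, Qwen3Config, Qwen3MoeConfig

# Hypothetical llm_config block, shaped like the one InternVLChatConfig receives.
llm_config = {
    "architectures": ["Qwen3ForCausalLM"],
    "hidden_size": 1024,
    "num_hidden_layers": 2,
    "num_attention_heads": 16,
}

# Illustrative dispatch table mirroring the elif chain above.
_CONFIG_BY_ARCH = {
    "Qwen2ForCausalLM": Qwen2Config,
    "Qwen3ForCausalLM": Qwen3Config,
    "Qwen3MoeForCausalLM": Qwen3MoeConfig,
}

arch = llm_config.get("architectures")[0]
try:
    config_cls = _CONFIG_BY_ARCH[arch]
except KeyError:
    raise ValueError("Unsupported architecture: {}".format(arch))
cfg = config_cls(**llm_config)
print(type(cfg).__name__, cfg.hidden_size)  # Qwen3Config 1024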
python/sglang/srt/models/internvl.py  (view file @ f64b8e3e)
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
from sglang.srt.models.gpt_oss import GptOssForCausalLM
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3 import Qwen3ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger
@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
            self.language_model = Qwen3MoeForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        elif config.llm_config.architectures[0] == "GptOssForCausalLM":
            self.language_model = GptOssForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
            self.language_model = Qwen3ForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        else:
            raise NotImplementedError(
                f"{config.llm_config.architectures[0]} is not implemented."
...
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
                ckpt_up_proj_name="up_proj",
                num_experts=self.config.num_experts,
            )
        elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
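The stacked_params_mapping added for the Qwen3 backbone follows the usual fused-projection loading scheme: the checkpoint stores q_proj, k_proj, and v_proj (and gate_proj/up_proj) as separate tensors, while the runtime module keeps a single fused qkv_proj (and gate_up_proj) parameter. Below is a minimal, self-contained sketch of how such a mapping is typically consumed; it is not sglang's actual loader, and the helper name and weight names are illustrative.

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def resolve(ckpt_name):
    """Map a checkpoint weight name to (runtime parameter name, shard id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in ckpt_name:
            # The shard id tells the fused parameter's weight loader which
            # slice of qkv_proj / gate_up_proj this tensor fills.
            return ckpt_name.replace(shard_name, param_name), shard_id
    return ckpt_name, None  # unfused weights load one-to-one

print(resolve("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(resolve("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)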
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        # Skip params that are created by quantization wrappers and are not expected in the ckpt
        _quant_only_fragments = (
            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
        )
        unloaded_params = {
            n
            for n in unloaded_params
            if not any(frag in n for frag in _quant_only_fragments)
        }
        if unloaded_params:
            raise RuntimeError(
                f"Some weights are not initialized from checkpoints: {unloaded_params}"
...
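The second part of this hunk relaxes the post-load sanity check: parameters that exist only because of quantization wrappers, such as per-matrix FP8 weight scales, are never present in the checkpoint and should not trip the "missing weights" error. A small self-contained illustration of the filter with hypothetical parameter names:

_quant_only_fragments = ("weight_scale",)

# Hypothetical runtime parameters and the subset actually found in the checkpoint.
params_dict_keys = {
    "language_model.model.layers.0.mlp.experts.w13_weight",
    "language_model.model.layers.0.mlp.experts.w13_weight_scale",  # created by quantization
    "language_model.model.layers.0.mlp.experts.w2_weight_scale",   # created by quantization
}
loaded_params = {"language_model.model.layers.0.mlp.experts.w13_weight"}

unloaded_params = params_dict_keys - loaded_params
unloaded_params = {
    n for n in unloaded_params
    if not any(frag in n for frag in _quant_only_fragments)
}
assert not unloaded_params  # only quant-only scales were missing, so no error is raised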