Unverified commit f64b8e3e, authored by yilian49, committed by GitHub

Support the internvl3.5 family models in sglang (#9705)

parent 53976fce
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece as spm
 from transformers import (
     TOKENIZER_MAPPING,
+    GptOssConfig,
     LlamaConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     Qwen2Config,
     Qwen3Config,
+    Qwen3MoeConfig,
 )
 from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
         elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
             self.llm_config = Qwen2Config(**llm_config)
         elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+            self.llm_config = Qwen3MoeConfig(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
             self.llm_config = Qwen3Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+            self.llm_config = GptOssConfig(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(
...
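The hunk above wires three more inner-LLM architectures (Qwen3, Qwen3-MoE, GPT-OSS) to their transformers config classes; notably, Qwen3MoeForCausalLM previously fell through to Qwen3Config and now gets its own Qwen3MoeConfig. A minimal table-driven sketch of the same dispatch (not the sglang source; it assumes the installed transformers version provides all four config classes):

from transformers import GptOssConfig, Qwen2Config, Qwen3Config, Qwen3MoeConfig

# Architecture name (from the checkpoint's llm_config["architectures"][0])
# mapped to the matching transformers config class.
_ARCH_TO_CONFIG = {
    "Qwen2ForCausalLM": Qwen2Config,
    "Qwen3ForCausalLM": Qwen3Config,
    "Qwen3MoeForCausalLM": Qwen3MoeConfig,  # no longer falls back to Qwen3Config
    "GptOssForCausalLM": GptOssConfig,
}

def build_llm_config(llm_config: dict):
    arch = (llm_config.get("architectures") or [None])[0]
    if arch not in _ARCH_TO_CONFIG:
        raise ValueError("Unsupported architecture: {}".format(arch))
    return _ARCH_TO_CONFIG[arch](**llm_config)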
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
+from sglang.srt.models.gpt_oss import GptOssForCausalLM
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "GptOssForCausalLM":
+            self.language_model = GptOssForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
             )
+        elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
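The Qwen3 branch reuses the standard fused-parameter mapping: separate q/k/v checkpoint tensors load into one fused qkv_proj, and gate/up into one gate_up_proj. A simplified sketch of how a load_weights loop typically consumes such a mapping (names condensed; the real loop also handles expert weights and skipped keys):

def _load_one(name, loaded_weight, params_dict, stacked_params_mapping):
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name not in name:
            continue
        # e.g. "...self_attn.q_proj.weight" -> "...self_attn.qkv_proj.weight"
        fused = name.replace(shard_name, param_name)
        param = params_dict[fused]
        # The fused parameter's weight_loader places this shard ("q"/"k"/"v",
        # or 0/1 for gate/up) at the correct offset inside the fused tensor.
        param.weight_loader(param, loaded_weight, shard_id)
        return fused
    # Not a stacked parameter: load it directly.
    param = params_dict[name]
    getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight)
    return name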
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
                 loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
+        # Skip params that are created by quantization wrappers and are not
+        # expected in the ckpt.
+        _quant_only_fragments = (
+            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
+        )
+        unloaded_params = {
+            n
+            for n in unloaded_params
+            if not any(frag in n for frag in _quant_only_fragments)
+        }
         if unloaded_params:
             raise RuntimeError(
                 f"Some weights are not initialized from checkpoints: {unloaded_params}"
...
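The final hunk relaxes the completeness check: parameters materialized by quantization wrappers (per-matrix FP8 weight scales) never appear in the checkpoint, so they are filtered out before any weight is declared uninitialized. A self-contained illustration of the filter on made-up parameter names:

_quant_only_fragments = ("weight_scale",)

unloaded_params = {
    "model.layers.0.mlp.experts.w2_weight_scale",   # created by the FP8 wrapper
    "model.layers.0.mlp.experts.w13_weight_scale",  # created by the FP8 wrapper
    "model.layers.0.self_attn.qkv_proj.weight",     # genuinely missing weight
}
unloaded_params = {
    n for n in unloaded_params
    if not any(frag in n for frag in _quant_only_fragments)
}
# Only the genuinely missing weight remains and would trigger the RuntimeError.
assert unloaded_params == {"model.layers.0.self_attn.qkv_proj.weight"}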