Unverified Commit a31ea448 authored by zxy's avatar zxy Committed by GitHub
Browse files

support for interns1-mini (#9299)

parent 439df454
...@@ -21,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch ...@@ -21,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.internvl import InternVisionModel from sglang.srt.models.internvl import InternVisionModel
from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3 import Qwen3ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger from sglang.utils import logger
...@@ -70,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module): ...@@ -70,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module):
self.language_model = Qwen3MoeForCausalLM( self.language_model = Qwen3MoeForCausalLM(
config=config.text_config, quant_config=quant_config config=config.text_config, quant_config=quant_config
) )
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(
config=config.text_config, quant_config=quant_config
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"{config.text_config.architectures[0]} is not implemented." f"{config.text_config.architectures[0]} is not implemented."
......
...@@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module): ...@@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module):
# For EAGLE3 support # For EAGLE3 support
self.capture_aux_hidden_states = False self.capture_aux_hidden_states = False
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self) -> nn.Embedding:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment