Unverified commit 3409aaab, authored by xm:D, committed by GitHub

Support InternVL3 (#5350)


Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent 73dcf2b3
@@ -270,6 +270,29 @@ register_chat_template(
    )
)
register_chat_template(
    ChatTemplate(
        name="janus",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
        image_token="<image_placeholder>\n",
    )
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
    ChatTemplate(

@@ -395,6 +418,20 @@ register_chat_template(
    )
)
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
register_chat_template(
    ChatTemplate(
        name="internvl-2-5",
        # Default system prompt (Chinese): "You are InternVL (书生·万象), a
        # multimodal large language model jointly developed by Shanghai AI
        # Laboratory, Tsinghua University, and several partner institutions."
        default_system_prompt="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)
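# Rendered prompt shape for "internvl-2-5" (illustrative, not part of this
# diff): one user turn composes as
#   <|im_start|>system\n{system prompt}<|im_end|>\n
#   <|im_start|>user\n{question}<|im_end|>\n
#   <|im_start|>assistant\n
# and decoding stops at "<|im_end|>" or "<|action_end|>".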
register_chat_template(
    ChatTemplate(
        name="granite-3-instruct",

@@ -565,6 +602,13 @@ def match_gemma3_instruct(model_path: str):
    return get_chat_template("gemma-it")
@register_chat_template_matching_function
def match_internvl_chat(model_path: str):
    model_path = model_path.lower()
    if "internvl" in model_path:
        return get_chat_template("internvl-2-5")
if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
...
This diff is collapsed.
@@ -538,6 +538,7 @@ multimodal_model_archs = [
    "Qwen2_5_VLForConditionalGeneration",
    "CLIPModel",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
]
...
@@ -48,6 +48,7 @@ class SeparatorStyle(IntEnum):
    DeepSeekVL2 = auto()
    QWEN2_VL_EMBED = auto()
    GEMMA3 = auto()
    MPT = auto()


@dataclasses.dataclass
@@ -327,6 +328,16 @@ class Conversation:
                ret += role
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
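For reference, a standalone sketch of what the new SeparatorStyle.MPT branch renders; the separator, role prefixes, and message list below are toy values mirroring the internvl-2-5 template registered further down, not sglang's actual rendering code:

sep = "<|im_end|>\n"
system_prompt = "<|im_start|>system\nYou are InternVL."
messages = [
    ("<|im_start|>user\n", "Hi"),
    ("<|im_start|>assistant\n", None),  # empty slot: prompt ends awaiting generation
]
ret = system_prompt + sep
for role, message in messages:
    if message:
        ret += role + message + sep
    else:
        ret += role
# ret == "<|im_start|>system\nYou are InternVL.<|im_end|>\n"
#        "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"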
@@ -570,8 +581,11 @@ def generate_chat_conv(
                    real_content += "\n"  # for video
                real_content += content.text
            elif content.type == "image_url":
                # NOTE: works for llava and internvl-2-5
                if conv.name == "internvl-2-5":
                    real_content = image_token + real_content
                else:
                    real_content += image_token
                conv.append_image(content.image_url.url)
            elif content.type == "audio_url":
                real_content += audio_token
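The branch above changes where the image placeholder lands: InternVL's reference usage puts the placeholder before the text, while llava-style templates append it. A toy illustration (strings invented):

image_token, real_content = "<image>", "What is shown here?"
# internvl-2-5 prepends:
assert image_token + real_content == "<image>What is shown here?"
# llava-style templates append:
assert real_content + image_token == "What is shown here?<image>"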
@@ -703,6 +717,19 @@ register_conv_template(
    )
)
register_conv_template(
    Conversation(
        name="internvl-2-5",
        system_template="<|im_start|>system\n{system_message}",
        # System message (Chinese): the same InternVL self-introduction as in
        # the chat template above.
        system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
        sep_style=SeparatorStyle.MPT,
        sep="<|im_end|>\n",
        stop_str=["<|im_end|>", "<|action_end|>"],
        image_token="<image>",
    )
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
    Conversation(
...
@@ -19,6 +19,7 @@ import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union

import transformers
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
@@ -26,6 +27,7 @@ from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -38,6 +40,7 @@ from sglang.srt.configs import (
    KimiVLConfig,
    MultiModalityConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
@@ -48,6 +51,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    DeepseekVL2Config.model_type: DeepseekVL2Config,
    MultiModalityConfig.model_type: MultiModalityConfig,
    KimiVLConfig.model_type: KimiVLConfig,
    InternVLChatConfig.model_type: InternVLChatConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
@@ -90,6 +94,12 @@ def get_config(
        config = config_class.from_pretrained(model, revision=revision)
        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
        setattr(config, "_name_or_path", model)

    # InternVL nests its language-model settings under `llm_config`; promote
    # them to the top level so downstream code can read them uniformly.
    if isinstance(model, str) and config.model_type == "internvl_chat":
        for key, val in config.llm_config.__dict__.items():
            if not hasattr(config, key):
                setattr(config, key, val)

    if model_override_args:
        config.update(model_override_args)
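Why the promotion above is needed, as a runnable toy sketch; SimpleNamespace stands in for the real InternVLChatConfig and the field values are invented:

from types import SimpleNamespace

config = SimpleNamespace(
    model_type="internvl_chat",
    llm_config=SimpleNamespace(hidden_size=2048, num_hidden_layers=24),
)
for key, val in config.llm_config.__dict__.items():
    if not hasattr(config, key):
        setattr(config, key, val)
assert config.hidden_size == 2048  # now visible where sglang expects it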
@@ -211,6 +221,13 @@ def get_tokenizer(
    return tokenizer


# Some models don't have an available processor, e.g. InternVL, whose loaded
# "processor" is already a plain tokenizer.
def get_tokenizer_from_processor(processor):
    if isinstance(processor, PreTrainedTokenizerBase):
        return processor
    return processor.tokenizer
def get_processor(
    tokenizer_name: str,
    *args,
@@ -246,7 +263,9 @@ def get_processor(
        **kwargs,
    )

    tokenizer = get_tokenizer_from_processor(processor)
    attach_additional_stop_token_ids(tokenizer)

    return processor
...
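A quick standalone check of the helper's contract (the helper is re-declared here for illustration; gpt2 is used only because it is small to download):

from transformers import AutoTokenizer, PreTrainedTokenizerBase

def get_tokenizer_from_processor(processor):
    if isinstance(processor, PreTrainedTokenizerBase):
        return processor
    return processor.tokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
assert get_tokenizer_from_processor(tok) is tok  # plain tokenizers pass through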
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.internvl import InternVLChatModel
class InternVLImageProcessor(BaseMultimodalProcessor):
    models = [InternVLChatModel]

    def __init__(self, hf_config, server_args, _image_processor):
        super().__init__(hf_config, server_args, _image_processor)
        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
        patch_size = hf_config.vision_config.patch_size

        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
        self.IMG_START_TOKEN = "<img>"
        self.IMG_END_TOKEN = "</img>"
        self.IMG_TOKEN = "<image>"
        # Number of LLM tokens contributed by one image tile after
        # pixel-shuffle downsampling.
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
        )

        # For InternVL the "processor" is the tokenizer itself.
        tokenizer = self._processor
        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
            self.IMG_CONTEXT_TOKEN
        )
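    # Worked example (illustrative, not part of the diff): with the usual
    # InternVL2.5 settings image_size=448, patch_size=14, downsample_ratio=0.5:
    #   (448 // 14) ** 2 * 0.5 ** 2 == 1024 * 0.25 == 256
    # so each 448x448 tile contributes 256 <IMG_CONTEXT> tokens.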
    @staticmethod
    def build_transform(input_size):
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)

        def resize_image(img, size):
            return img.resize((size, size), Image.Resampling.BICUBIC)

        def to_tensor(img):
            # Convert PIL Image to numpy array
            img_array = np.array(img).astype(np.float32) / 255.0
            # Convert HWC to CHW format
            img_array = img_array.transpose(2, 0, 1)
            return torch.from_numpy(img_array)

        def normalize(tensor, mean, std):
            mean = torch.tensor(mean).view(-1, 1, 1)
            std = torch.tensor(std).view(-1, 1, 1)
            return (tensor - mean) / std

        def transform(img):
            img = img.convert("RGB") if img.mode != "RGB" else img
            img = resize_image(img, input_size)
            tensor = to_tensor(img)
            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
            return tensor

        return transform
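    # Usage (illustrative): build_transform(448) maps any PIL image to a
    # float32 tensor of shape (3, 448, 448), scaled to [0, 1] and normalized
    # with the ImageNet statistics above.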
    @staticmethod
    def dynamic_preprocess(
        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
    ):
        def find_closest_aspect_ratio(
            aspect_ratio, target_ratios, width, height, image_size
        ):
            best_ratio_diff = float("inf")
            best_ratio = (1, 1)
            area = width * height
            for ratio in target_ratios:
                target_aspect_ratio = ratio[0] / ratio[1]
                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
                if ratio_diff < best_ratio_diff:
                    best_ratio_diff = ratio_diff
                    best_ratio = ratio
                elif ratio_diff == best_ratio_diff:
                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                        best_ratio = ratio
            return best_ratio

        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if i * j <= max_num and i * j >= min_num
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )

        # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        # resize the image
        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size,
            )
            # split the image
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)
        return processed_images
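    # Worked example (illustrative): an 800x600 image with max_num=12 has
    # aspect ratio 4:3, so find_closest_aspect_ratio picks (4, 3); the image
    # is resized to 1792x1344 and cropped into 12 tiles of 448x448, plus one
    # 448x448 thumbnail when use_thumbnail=True, giving 13 tiles in total.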
    @staticmethod
    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / num_segments
        frame_indices = np.array(
            [
                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
                for idx in range(num_segments)
            ]
        )
        return frame_indices
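    # Worked example (illustrative, values computed by hand): for a 30 fps
    # clip with max_frame=299 and num_segments=8, each segment spans
    # 299 / 8 = 37.375 frames and the sampled midpoint indices come out as
    # [18, 55, 93, 130, 168, 205, 242, 280].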
    @staticmethod
    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())
        pixel_values_list, num_patches_list = [], []
        transform = InternVLImageProcessor.build_transform(input_size=input_size)
        frame_indices = InternVLImageProcessor.get_index(
            bound, fps, max_frame, first_idx=0, num_segments=num_segments
        )
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
            img = InternVLImageProcessor.dynamic_preprocess(
                img, image_size=input_size, use_thumbnail=True, max_num=max_num
            )
            pixel_values = [transform(tile) for tile in img]
            pixel_values = torch.stack(pixel_values)
            num_patches_list.append(pixel_values.shape[0])
            pixel_values_list.append(pixel_values)
        pixel_values = torch.cat(pixel_values_list)
        return pixel_values, num_patches_list
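    # Output shapes (illustrative): with the default max_num=1 every frame
    # yields exactly one 448x448 tile (no split is possible, and the
    # thumbnail branch is skipped for a single tile), so pixel_values is
    # [num_segments, 3, 448, 448] and num_patches_list == [1] * num_segments.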
    async def process_mm_data_async(
        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
    ):
        if not image_data:
            return None

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )

        def process_image_internvl(image, input_size=448, max_num=12):
            transform = InternVLImageProcessor.build_transform(input_size=input_size)
            images = InternVLImageProcessor.dynamic_preprocess(
                image, image_size=input_size, use_thumbnail=True, max_num=max_num
            )
            pixel_values = [transform(image) for image in images]
            pixel_values = torch.stack(pixel_values)
            return pixel_values

        num_patches_list = []
        pixel_values = []
        # Process each input with allocated frames
        for image_index, image in enumerate(base_output.images):
            try:
                # TODO: video input
                raw_image = process_image_internvl(image)
                pixel_value = [raw_image.to(torch.bfloat16).cuda()]
                pixel_values += pixel_value
                num_patches = raw_image.shape[0]
                num_patches_list += [num_patches]
            except FileNotFoundError as e:
                print(e)
                return None

        pixel_values = torch.cat(pixel_values, dim=0)
        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]

        for idx, num_patches in enumerate(num_patches_list):
            image_tokens = (
                self.IMG_START_TOKEN
                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                + self.IMG_END_TOKEN
            )
            input_text = input_text.replace("<image>", image_tokens, 1)

        tokenizer = self._processor
        return {
            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
            .flatten()
            .tolist(),
            "mm_items": items,
            "im_start_id": self.img_start_token_id,
            "im_end_id": self.img_end_token_id,
            "im_token_id": self.img_context_token_id,
        }
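The placeholder-expansion loop above is the heart of the processor. A self-contained sketch with toy numbers, not part of the diff (256 tokens per tile as derived earlier, 13 tiles as in the 12-tiles-plus-thumbnail case):

num_image_token, num_patches = 256, 13
image_tokens = "<img>" + "<IMG_CONTEXT>" * (num_image_token * num_patches) + "</img>"
prompt = "<image>\nDescribe the photo.".replace("<image>", image_tokens, 1)
# The tokenizer sees 3328 IMG_CONTEXT positions, which the model later
# overwrites with the vision encoder's embeddings.
assert prompt.count("<IMG_CONTEXT>") == 3328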
@@ -52,7 +52,11 @@ from sglang.srt.disaggregation.utils import (
    TransferBackend,
)
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -475,7 +479,7 @@ class Scheduler(
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
...
@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
    TransferBackend,
    get_kv_class,
)
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
@@ -199,7 +203,7 @@ class TokenizerManager:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = get_dummy_processor()
...
@@ -21,7 +21,11 @@ import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
@@ -102,7 +106,7 @@ class TpModelWorker:
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
...
@@ -290,6 +290,9 @@ class InternLM2ForCausalLM(nn.Module):
        )
        self.logits_processor = LogitsProcessor(config)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.tok_embeddings

    @torch.no_grad()
    def forward(
        self,
...
This diff is collapsed.
@@ -604,6 +604,21 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
        cls.base_url += "/v1"


class TestInternVL2_5Server(TestOpenAIVisionServer):
    @classmethod
    def setUpClass(cls):
        cls.model = "OpenGVLab/InternVL2_5-2B"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--trust-remote-code", "--chat-template", "internvl-2-5"],
        )
        cls.base_url += "/v1"
class TestMinicpmoServer(TestOpenAIVisionServer):
    @classmethod
    def setUpClass(cls):
...
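To hit the same endpoint outside the test harness, something like the following should work once a server is launched with --trust-remote-code and --chat-template internvl-2-5; the base URL, port, and image URL here are placeholders, not values from this diff:

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
response = client.chat.completions.create(
    model="OpenGVLab/InternVL2_5-2B",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)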