Unverified commit 9d02bb3e authored by Mick, committed by GitHub

Urgent model support: support gemma-3-it (#4424)

parent 402db5c5
@@ -32,6 +32,7 @@
- Phi-3-Small
- IBM Granite 3
- Janus-Pro-1B / Janus-Pro-7B
- Gemma 3 (it)
## Embedding Models
...
@@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str):
    return get_chat_template("granite-3-instruct")
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
model_path = model_path.lower()
if "gemma-3" in model_path and "1b" not in model_path:
        # gemma-3-1b-it is a completion model
return get_chat_template("gemma-it")
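A quick illustration of how the matcher above resolves a few hypothetical model paths (the 1b variant is deliberately skipped, and non-Gemma-3 paths fall through to other matchers):

```python
# Illustration only: hypothetical model paths.
for path in ["google/gemma-3-4b-it", "google/gemma-3-1b-it", "google/gemma-2-9b-it"]:
    print(path, "->", match_gemma3_instruct(path))
# gemma-3-4b-it  -> the "gemma-it" chat template
# gemma-3-1b-it  -> None (contains "1b", treated as a completion model)
# gemma-2-9b-it  -> None (no "gemma-3" substring; other matchers apply)
```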
if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
...
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
    Qwen2_5_VLConfig,
@@ -14,4 +15,6 @@ __all__ = [
    "Qwen2_5_VLConfig",
    "Qwen2_5_VLVisionConfig",
    "MultiModalityConfig",
    "Gemma3Config",
    "Gemma3TextConfig",
]
@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            if config.model_type.startswith("gemma"):
                if config.model_type == "gemma":
                    gemma_version = ""
                else:
                    gemma_version = config.model_type[5]
                logger.info(
                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                    "of float16 by default. Please specify `dtype` if you "
                    "want to use float16."
                )
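A short sketch of how the version suffix is derived from `config.model_type` (model type strings assumed from the Transformers configs; index 5 is the first character after "gemma"):

```python
# Illustration of the gemma_version logic above (model_type values assumed).
for model_type in ["gemma", "gemma2", "gemma3", "gemma3_text"]:
    gemma_version = "" if model_type == "gemma" else model_type[5]
    print(model_type, "->", repr(gemma_version))
# gemma -> '', gemma2 -> '2', gemma3 -> '3', gemma3_text -> '3'
```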
@@ -453,6 +457,7 @@ multimodal_model_archs = [
    "LlavaQwenForCausalLM",
    "LlavaMistralForCausalLM",
    "LlavaVidForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "MllamaForConditionalGeneration",
...
@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
    DEEPSEEK_CHAT = auto()
    METAMATH = auto()
    QWEN2_VL_EMBED = auto()
GEMMA3 = auto()
@dataclasses.dataclass
@@ -285,6 +286,18 @@ class Conversation:
            else:
                ret += role + ":"
            return ret
elif self.sep_style == SeparatorStyle.GEMMA3:
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
if i == 0:
ret += message + self.sep
else:
ret += role + message + self.sep
else:
ret += role
return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
@@ -604,6 +617,20 @@ register_conv_template(
    )
)
# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
register_conv_template(
Conversation(
name="gemma-it",
system_message="You are a helpful assistant.",
system_template="<bos><start_of_turn>user{system_message}\n\n",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
sep="<end_of_turn>\n",
sep_style=SeparatorStyle.GEMMA3,
stop_str=["<end_of_turn>"],
image_token="<start_of_image>",
)
)
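For reference, a hand-worked sketch of the prompt this template produces for a single user turn followed by an empty model turn, following the GEMMA3 separator logic above (the first message is appended without its role tag because the user tag is already part of `system_template`):

```python
# Expected rendering for one user turn "Hello" and an empty model turn:
prompt = (
    "<bos><start_of_turn>user"          # system_template prefix
    "You are a helpful assistant.\n\n"  # system_message
    "Hello<end_of_turn>\n"              # first turn: message + sep only
    "<start_of_turn>model\n"            # empty model turn: role tag only
)
```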
# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
register_conv_template(
    Conversation(
...
@@ -34,6 +34,8 @@ from sglang.srt.configs import (
    ChatGLMConfig,
    DbrxConfig,
    ExaoneConfig,
    Gemma3Config,
    Gemma3TextConfig,
    MultiModalityConfig,
    Qwen2_5_VLConfig,
)
@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ExaoneConfig.model_type: ExaoneConfig,
    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
    MultiModalityConfig.model_type: MultiModalityConfig,
    Gemma3Config.model_type: Gemma3Config,
    Gemma3TextConfig.model_type: Gemma3TextConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
...
@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
    RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
from sglang.srt.utils import add_prefix
# Copied from transformers, modeling_qwen2_vl.py
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class VisionAttention(nn.Module):
    r"""
    Multi-headed attention without any cache, mostly used for ViT.
@@ -168,7 +144,7 @@ class VisionAttention(nn.Module):
            cos, sin = position_embeddings
            original_shape = q.shape
            q, k = q.view(s, head, -1), k.view(s, head, -1)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            q, k = q.reshape(original_shape), k.reshape(original_shape)
        if self.use_qkv_parallel:
...
@@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp):
        return out
class Gemma3RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
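A minimal sanity check of the ordering convention noted in the comment (illustrative; assumes the `Gemma3RMSNorm` class defined above and PyTorch):

```python
import torch

norm = Gemma3RMSNorm(dim=4)
x = torch.randn(2, 4, dtype=torch.bfloat16)
y = norm(x)
# The weight is initialized to zero, so this is plain RMS normalization, and
# the cast back to bfloat16 happens only after the (1 + weight) multiply.
assert y.dtype == x.dtype and y.shape == x.shape
```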
if not _is_cuda:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
...
@@ -1173,6 +1173,37 @@ def get_rope(
    return rotary_emb
# Copied from transformers
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
# embedding is performed in float
cos = cos.unsqueeze(unsqueeze_dim).float()
sin = sin.unsqueeze(unsqueeze_dim).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
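Shape conventions for the helper above (a sketch; with the default `unsqueeze_dim=1`, `cos`/`sin` of shape `(batch, seq, head_dim)` broadcast over a `(batch, heads, seq, head_dim)` query/key layout):

```python
import torch

q = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)
cos = torch.randn(1, 16, 64)                         # (batch, seq, head_dim)
sin = torch.randn(1, 16, 64)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
# The math runs in float32, but the outputs keep the original dtypes.
assert q_rot.dtype == torch.float16 and q_rot.shape == q.shape
```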
def get_rope_cpu(
    head_size: int,
    rotary_dim: int,
...
@@ -111,7 +111,7 @@ class BaseImageProcessor(ABC):
    def load_images(
        self,
        input_ids: list[int],
        image_data,
        image_token: str,
        max_req_input_len: int,
@@ -122,22 +122,21 @@ class BaseImageProcessor(ABC):
        Each frame of video/image will be replaced by a single image token
        Args:
            discard_alpha_channel: if True, discards the alpha channel in the returned images
        """
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
        if isinstance(input_ids, list) and return_text:
            assert len(input_ids) and isinstance(input_ids[0], int)
            input_text = self._processor.tokenizer.decode(input_ids)
        else:
            input_text = input_ids
        if return_text:
            import re

            pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
            # split text into list of normal text and special tokens
            text_parts = re.split(pattern, input_text)
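The capturing group keeps the image token itself as a separate element, so the loop below can tell image placeholders apart from plain text. For example:

```python
import re

image_token = "<start_of_image>"
pattern = "(" + re.escape(image_token) + ")"
print(re.split(pattern, "Describe <start_of_image> briefly."))
# ['Describe ', '<start_of_image>', ' briefly.']
```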
        # TODO(mick): load from server_args, env, or sampling_params
        MAX_NUM_FRAMES = 30
@@ -145,53 +144,65 @@
        total_frame_count = sum(estimated_frames_list)
        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
        _scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
        assert len(image_data) == len(estimated_frames_list)
        image_index, audio_index = 0, 0
        hashes, image_sizes, images, audios = [], [], [], []
        new_text = ""
        for index, text_part in enumerate(text_parts):
            try:
                if text_part == image_token:
                    # load as image
                    frames_to_process = estimated_frames_list[image_index]
                    if frames_to_process == 0:
                        frames = []
                    else:
                        image_file = image_data[image_index]
                        if isinstance(image_file, str) and image_file.startswith(
                            "video:"
                        ):
                            # video
                            path = image_file[len("video:") :]
                            frames = self.encode_video(
                                path, frame_count_limit=frames_to_process
                            )
                        else:
                            # image
                            raw_image, _size = load_image(image_file)
                            if discard_alpha_channel:
                                raw_image = raw_image.convert("RGB")
                            frames = [raw_image]
                    if len(frames) == 0:
                        continue

                    # one (width, height) entry per frame
                    image_sizes += [frames[0].size] * len(frames)
                    hashes += [hash(image_file)] * len(frames)
                    images += frames
                    image_index += 1
                    if frames_to_process != 0:
                        new_text += image_token * len(frames)
                    assert frames_to_process == len(frames)
                else:
                    # TODO(mick): handle video
                    # normal text
                    new_text += text_part
            except Exception as e:
                import openai

                logger.error(f"An exception occurred while loading images: {e}")
                raise BadRequestError(
                    f"An exception occurred while loading images: {e}"
                )

        return BaseImageProcessorOutput(
            image_hashes=hashes,
            image_sizes=image_sizes,
            all_frames=images,
            input_text=new_text,
        )
...
import asyncio
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Gemma3SGLangImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_images(
input_text=base_output.input_text, images=base_output.all_frames
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
ImageProcessorMapping = {
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}
@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
            image_data = [image_data]
        base_out = self.load_images(
            input_ids=input_ids,
            image_data=image_data,
            image_token="<image_placeholder>",
            max_req_input_len=max_req_input_len,
        )
        images = base_out.all_frames
        res = await self._process_images(images=images, input_text=base_out.input_text)
...
@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
            image_data = [image_data]
        base_output = self.load_images(
            input_ids=input_ids,
            image_data=image_data,
            image_token=self.IMAGE_TOKEN,
            max_req_input_len=max_req_input_len,
        )
        if base_output is None:
            return None
...
@@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids=input_ids,
            image_data=image_data,
            image_token=image_token,
            max_req_input_len=max_req_input_len,
        )
        def smart_resize(
...
@@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -207,6 +207,9 @@ class ImageInputs:
        return ret
    def merge(self, other):
"""
merge image inputs when requests are being merged
"""
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
...
@@ -33,6 +33,7 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
@@ -331,6 +332,32 @@ class ForwardBatch:
        return ret
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
"""
Merge all image inputs in the batch into a single ImageInputs object.
        Returns:
            None if the current batch contains no image input.
        """
if not self.image_inputs or all(x is None for x in self.image_inputs):
return None
# Filter out None values
valid_inputs = [x for x in self.image_inputs if x is not None]
# Start with the first valid image input
merged = valid_inputs[0]
# Merge remaining inputs
for img_input in valid_inputs[1:]:
merged.merge(img_input)
if isinstance(merged.pixel_values, np.ndarray):
merged.pixel_values = torch.from_numpy(merged.pixel_values)
return merged
    def _compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
...
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Gemma3ImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3MultiModalProjector(nn.Module):
"""Projector for Gemma3 multimodal."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(
config.vision_config.hidden_size, config.text_config.hidden_size
)
)
self.mm_soft_emb_norm = Gemma3RMSNorm(
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
batch_size, seq_length, hidden_size = vision_outputs.shape
# Reshape for pooling
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, hidden_size, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
# Apply pooling
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
# Apply normalization
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
# Project to text embedding space
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
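A shape walk-through of the projector (the concrete numbers are assumptions based on a typical Gemma 3 vision config: `image_size=896`, `patch_size=14`, `mm_tokens_per_image=256`):

```python
# Assumed config values for illustration only.
patches_per_image = 896 // 14                         # 64 patches per side -> 64 * 64 = 4096 tokens
tokens_per_side = int(256 ** 0.5)                     # 16
kernel_size = patches_per_image // tokens_per_side    # 4 -> 4x4 average pooling
# vision_outputs:  (B, 4096, vision_hidden)
# after avg_pool:  (B, 256, vision_hidden)
# after matmul:    (B, 256, text_hidden)  -- ready to be scattered into the text embeddings
```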
class Gemma3ForConditionalGeneration(PreTrainedModel):
    """Gemma3 multimodal model for conditional generation."""

    config_class = Gemma3Config
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
# Vision components
# TODO: replace with vision attention
# self.vision_tower = SiglipVisionModel(
# config.vision_config,
# quant_config,
# prefix=add_prefix("vision_tower", prefix),
# )
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
# Text model
self.language_model = Gemma3ForCausalLM(
config.text_config, quant_config, prefix=add_prefix("model", prefix)
)
if self.language_model.logits_processor.logit_scale:
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.logit_scale *= logit_scale
self.post_init()
def pad_input_ids(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""Pad input IDs with image tokens."""
# Get special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
ids = pattern.pad_input_tokens(input_ids, image_inputs)
return ids
def prepare_attn_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
**kwargs,
) -> Dict:
"""Prepare attention masks for multimodal inputs."""
kwargs["has_images"] = True
# Distinguish sequences by position id 0
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i].item()
if i < num_seqs - 1:
end_idx = start_indices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens
# Create attention masks
global_attn_masks = []
local_attn_masks = []
sliding_window = self.config.text_config.interleaved_sliding_window
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create global causal mask
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Consider bidirectional attention between image tokens
img_mask = torch.zeros_like(global_attn_mask)
img_pos = input_token_ids == self.config.image_token_index
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
# Create local causal mask with sliding window
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
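A toy illustration of the image-token trick used above: positions where both the query and the key are image tokens are re-enabled on top of the causal mask (token id 100 stands in for `config.image_token_index`):

```python
import torch

input_token_ids = torch.tensor([5, 100, 100, 7])  # 100 = image token (assumed id)
img_pos = input_token_ids == 100
causal = torch.full((1, 1, 4, 4), float("-inf")).triu(diagonal=1)
img_mask = torch.zeros_like(causal)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
mask = torch.where(img_mask == 2, 0.0, causal)
# The two image tokens attend to each other bidirectionally; plain text stays causal.
assert mask[0, 0, 1, 2] == 0 and mask[0, 0, 0, 3] == float("-inf")
```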
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def get_image_features(self, pixel_values: torch.Tensor):
"""
Projects the last hidden state from the vision model into language model space.
Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def embed_image_inputs(
self,
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
image_input: ImageInputs,
) -> torch.Tensor:
if input_ids is None:
raise ValueError("Unimplemented")
# boolean-masking image tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(image_input.pad_values, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
inputs_embeds = None
if num_image_tokens_in_input_ids == 0:
inputs_embeds = self.get_input_embeddings()(input_ids)
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_features(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
image_features.shape[0] * image_features.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
image_features = image_features[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
)
# Important: clamp after extracting original image boundaries
input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
return inputs_embeds
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs: object,
) -> LogitsProcessor:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
        # Important: position_ids in Gemma3 are 1-indexed
        # (this really did cost some debugging time)
        positions += 1

        # Replace the image token id with PAD if it is out of vocabulary, to avoid index errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
merged_image_input = forward_batch.get_merged_image_inputs()
if (
not forward_batch.forward_mode.is_decode()
and merged_image_input is not None
):
inputs_embeds = self.embed_image_inputs(
input_ids=llm_input_ids,
forward_batch=forward_batch,
image_input=merged_image_input,
)
else:
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
outputs = self.language_model(
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return outputs
def tie_weights(self):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model."""
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "language_model" in name:
# Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
causal_loaded_params = Gemma3ForCausalLM.load_weights(
self, [(name, loaded_weight)]
)
loaded_params.update(causal_loaded_params)
continue
else:
# Skip lm_head.weight as it's tied with embed_tokens
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
pass
# raise RuntimeError(
# f"Some weights are not initialized from checkpoints: {unloaded_params}")
return loaded_params
EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
@@ -41,7 +41,6 @@ from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
        image = Image.open(BytesIO(image_file))
    elif image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, stream=True, timeout=timeout).raw
        image = Image.open(response)
        response.close()
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):
...