Unverified Commit 9d02bb3e authored by Mick's avatar Mick Committed by GitHub
Browse files

Urgent model support: support gemma-3-it (#4424)

parent 402db5c5
......@@ -32,6 +32,7 @@
- Phi-3-Small
- IBM Granite 3
- Janus-Pro-1B / Janus-Pro-7B
- Gemma 3 (it)
## Embedding Models
......
......@@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str):
return get_chat_template("granite-3-instruct")
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
model_path = model_path.lower()
if "gemma-3" in model_path and "1b" not in model_path:
# gemma-3-1b-it is completion model
return get_chat_template("gemma-it")
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default
......
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
Qwen2_5_VLConfig,
......@@ -14,4 +15,6 @@ __all__ = [
"Qwen2_5_VLConfig",
"Qwen2_5_VLVisionConfig",
"MultiModalityConfig",
"Gemma3Config",
"Gemma3TextConfig",
]
This diff is collapsed.
......@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
if config.model_type.startswith("gemma"):
if config.model_type == "gemma":
gemma_version = ""
else:
gemma_version = config.model_type[5]
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
......@@ -453,6 +457,7 @@ multimodal_model_archs = [
"LlavaQwenForCausalLM",
"LlavaMistralForCausalLM",
"LlavaVidForCausalLM",
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"MllamaForConditionalGeneration",
......
......@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
DEEPSEEK_CHAT = auto()
METAMATH = auto()
QWEN2_VL_EMBED = auto()
GEMMA3 = auto()
@dataclasses.dataclass
......@@ -285,6 +286,18 @@ class Conversation:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.GEMMA3:
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
if i == 0:
ret += message + self.sep
else:
ret += role + message + self.sep
else:
ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
......@@ -604,6 +617,20 @@ register_conv_template(
)
)
# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
register_conv_template(
Conversation(
name="gemma-it",
system_message="You are a helpful assistant.",
system_template="<bos><start_of_turn>user{system_message}\n\n",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
sep="<end_of_turn>\n",
sep_style=SeparatorStyle.GEMMA3,
stop_str=["<end_of_turn>"],
image_token="<start_of_image>",
)
)
# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
register_conv_template(
Conversation(
......
......@@ -34,6 +34,8 @@ from sglang.srt.configs import (
ChatGLMConfig,
DbrxConfig,
ExaoneConfig,
Gemma3Config,
Gemma3TextConfig,
MultiModalityConfig,
Qwen2_5_VLConfig,
)
......@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ExaoneConfig.model_type: ExaoneConfig,
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
MultiModalityConfig.model_type: MultiModalityConfig,
Gemma3Config.model_type: Gemma3Config,
Gemma3TextConfig.model_type: Gemma3TextConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
......@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
from sglang.srt.utils import add_prefix
# Copied from transformers, modeling_qwen2_vl.py
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class VisionAttention(nn.Module):
r"""
Multi-headed attention without any cache, mostly used for ViT.
......@@ -168,7 +144,7 @@ class VisionAttention(nn.Module):
cos, sin = position_embeddings
original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
if self.use_qkv_parallel:
......
......@@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp):
return out
class Gemma3RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
......
......@@ -1173,6 +1173,37 @@ def get_rope(
return rotary_emb
# Copied from transformers
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
# embedding is performed in float
cos = cos.unsqueeze(unsqueeze_dim).float()
sin = sin.unsqueeze(unsqueeze_dim).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
def get_rope_cpu(
head_size: int,
rotary_dim: int,
......
......@@ -111,7 +111,7 @@ class BaseImageProcessor(ABC):
def load_images(
self,
input_ids: list,
input_ids: list[int],
image_data,
image_token: str,
max_req_input_len: int,
......@@ -122,22 +122,21 @@ class BaseImageProcessor(ABC):
Each frame of video/image will be replaced by a single image token
Args:
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
if return_text:
text_parts = input_text.split(image_token)
import re
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
......@@ -145,53 +144,65 @@ class BaseImageProcessor(ABC):
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
_scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
max_frames_to_process = 0
else:
max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
if max_frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = BaseImageProcessor.encode_video(
path, frame_count_limit=max_frames_to_process
)
image_index, audio_index = 0, 0
hashes, image_sizes, images, audios = [], [], [], []
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == image_token:
# load as image
frames_to_process = estimated_frames_list[image_index]
if frames_to_process == 0:
frames = []
else:
raw_image, _size = load_image(image)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
assert len(frames) != 0
except FileNotFoundError as e:
print(e)
return None
image_sizes += [frames[0].size] * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
if return_text:
new_text_parts.append(text_parts[image_index])
if max_frames_to_process != 0:
new_text_parts.append(image_token * len(frames))
assert max_frames_to_process >= len(frames)
if return_text:
new_text_parts.append(text_parts[-1])
image_file = image_data[image_index]
if isinstance(image_file, str) and image_file.startswith(
"video:"
):
# video
path = image_file[len("video:") :]
frames = self.encode_video(
path, frame_count_limit=frames_to_process
)
else:
# image
raw_image, _size = load_image(image_file)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
if len(frames) == 0:
continue
image_sizes += frames[0].size * len(frames)
hashes += [hash(image_file)] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
new_text += image_token * len(frames)
assert frames_to_process == len(frames)
else:
# TODO(mick): handle video
# normal text
new_text += text_part
except Exception as e:
import openai
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
continue
input_text = "".join(new_text_parts)
return BaseImageProcessorOutput(
image_hashes, image_sizes, all_frames, input_text
image_hashes=hashes,
image_sizes=image_sizes,
all_frames=images,
input_text=new_text,
)
......
import asyncio
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Gemma3SGLangImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_images(
input_text=base_output.input_text, images=base_output.all_frames
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
ImageProcessorMapping = {
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}
......@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
image_data = [image_data]
base_out = self.load_images(
input_ids, image_data, "<image_placeholder>", max_req_input_len
input_ids=input_ids,
image_data=image_data,
image_token="<image_placeholder>",
max_req_input_len=max_req_input_len,
)
images = base_out.all_frames
res = await self._process_images(images=images, input_text=base_out.input_text)
......
......@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
image_data = [image_data]
base_output = self.load_images(
input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
input_ids=input_ids,
image_data=image_data,
image_token=self.IMAGE_TOKEN,
max_req_input_len=max_req_input_len,
)
if base_output is None:
return None
......
......@@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids,
image_data,
image_token,
max_req_input_len,
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
max_req_input_len=max_req_input_len,
)
def smart_resize(
......
......@@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
......@@ -207,6 +207,9 @@ class ImageInputs:
return ret
def merge(self, other):
"""
merge image inputs when requests are being merged
"""
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
......
......@@ -33,6 +33,7 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
......@@ -331,6 +332,32 @@ class ForwardBatch:
return ret
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
"""
Merge all image inputs in the batch into a single ImageInputs object.
Returns:
if none, current batch contains no image input
"""
if not self.image_inputs or all(x is None for x in self.image_inputs):
return None
# Filter out None values
valid_inputs = [x for x in self.image_inputs if x is not None]
# Start with the first valid image input
merged = valid_inputs[0]
# Merge remaining inputs
for img_input in valid_inputs[1:]:
merged.merge(img_input)
if isinstance(merged.pixel_values, np.ndarray):
merged.pixel_values = torch.from_numpy(merged.pixel_values)
return merged
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
......
This diff is collapsed.
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Gemma3ImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3MultiModalProjector(nn.Module):
"""Projector for Gemma3 multimodal."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(
config.vision_config.hidden_size, config.text_config.hidden_size
)
)
self.mm_soft_emb_norm = Gemma3RMSNorm(
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
batch_size, seq_length, hidden_size = vision_outputs.shape
# Reshape for pooling
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, hidden_size, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
# Apply pooling
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
# Apply normalization
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
# Project to text embedding space
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(PreTrainedModel):
config_class = Gemma3Config
"""Gemma3 multimodal model for conditional generation."""
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
# Vision components
# TODO: replace with vision attention
# self.vision_tower = SiglipVisionModel(
# config.vision_config,
# quant_config,
# prefix=add_prefix("vision_tower", prefix),
# )
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
# Text model
self.language_model = Gemma3ForCausalLM(
config.text_config, quant_config, prefix=add_prefix("model", prefix)
)
if self.language_model.logits_processor.logit_scale:
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.logit_scale *= logit_scale
self.post_init()
def pad_input_ids(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""Pad input IDs with image tokens."""
# Get special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
ids = pattern.pad_input_tokens(input_ids, image_inputs)
return ids
def prepare_attn_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
**kwargs,
) -> Dict:
"""Prepare attention masks for multimodal inputs."""
kwargs["has_images"] = True
# Distinguish sequences by position id 0
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i].item()
if i < num_seqs - 1:
end_idx = start_indices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens
# Create attention masks
global_attn_masks = []
local_attn_masks = []
sliding_window = self.config.text_config.interleaved_sliding_window
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create global causal mask
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Consider bidirectional attention between image tokens
img_mask = torch.zeros_like(global_attn_mask)
img_pos = input_token_ids == self.config.image_token_index
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
# Create local causal mask with sliding window
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def get_image_features(self, pixel_values: torch.Tensor):
"""
Projects the last hidden state from the vision model into language model space.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def embed_image_inputs(
self,
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
image_input: ImageInputs,
) -> torch.Tensor:
if input_ids is None:
raise ValueError("Unimplemented")
# boolean-masking image tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(image_input.pad_values, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
inputs_embeds = None
if num_image_tokens_in_input_ids == 0:
inputs_embeds = self.get_input_embeddings()(input_ids)
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_features(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
image_features.shape[0] * image_features.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
image_features = image_features[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
)
# Important: clamp after extracting original image boundaries
input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
return inputs_embeds
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs: object,
) -> LogitsProcessor:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
# Important: position_ids in Gemma3 are 1-indexed
# This really does cost me sometime
positions += 1
# Replace image id with PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
merged_image_input = forward_batch.get_merged_image_inputs()
if (
not forward_batch.forward_mode.is_decode()
and merged_image_input is not None
):
inputs_embeds = self.embed_image_inputs(
input_ids=llm_input_ids,
forward_batch=forward_batch,
image_input=merged_image_input,
)
else:
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
outputs = self.language_model(
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return outputs
def tie_weights(self):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model."""
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "language_model" in name:
# Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
causal_loaded_params = Gemma3ForCausalLM.load_weights(
self, [(name, loaded_weight)]
)
loaded_params.update(causal_loaded_params)
continue
else:
# Skip lm_head.weight as it's tied with embed_tokens
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
pass
# raise RuntimeError(
# f"Some weights are not initialized from checkpoints: {unloaded_params}")
return loaded_params
EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
......@@ -41,7 +41,6 @@ from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
......@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout)
image = Image.open(BytesIO(response.content))
response = requests.get(image_file, stream=True, timeout=timeout).raw
image = Image.open(response)
response.close()
elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
image = Image.open(image_file)
elif image_file.startswith("data:"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment