Unverified Commit c0c25e25 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model] Add support for Gemma 3 (#14660)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 45f3f3f5
...@@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `Gemma2ForCausalLM` - * `Gemma2ForCausalLM`
* Gemma2 * Gemma 2
* `google/gemma-2-9b`, `google/gemma-2-27b`, etc. * `google/gemma-2-9b`, `google/gemma-2-27b`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `Gemma3ForCausalLM`
* Gemma 3
* `google/gemma-3-1b-it`, etc.
* ✅︎
* ✅︎
- * `GlmForCausalLM` - * `GlmForCausalLM`
* GLM-4 * GLM-4
* `THUDM/glm-4-9b-chat-hf`, etc. * `THUDM/glm-4-9b-chat-hf`, etc.
...@@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in ...@@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in
* *
* *
- * `Gemma2Model` - * `Gemma2Model`
* Gemma2-based * Gemma 2-based
* `BAAI/bge-multilingual-gemma2`, etc. * `BAAI/bge-multilingual-gemma2`, etc.
* *
* ✅︎ * ✅︎
...@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ
* *
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `Gemma3ForConditionalGeneration`
* Gemma 3
* T + I<sup>+</sup>
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
* ✅︎
* ✅︎
* ✅︎\*
- * `GLM4VForCausalLM`<sup>^</sup> - * `GLM4VForCausalLM`<sup>^</sup>
* GLM-4V * GLM-4V
* T + I * T + I
...@@ -937,6 +949,31 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630> ...@@ -937,6 +949,31 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`. To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
::: :::
:::{note}
To use Gemma3 series models, you have to install Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:
V0 correctly implements the model's attention pattern:
- Uses bidirectional attention between the image tokens corresponding to the same image
- Uses causal attention for other tokens
- Implemented via (naive) PyTorch SDPA with masking tensors
- Note: May use significant memory for long prompts with image
V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs
- Will be updated in the future to support the correct behavior
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
:::
### Pooling Models ### Pooling Models
See [this page](pooling-models) for more information on how to use pooling models. See [this page](pooling-models) for more information on how to use pooling models.
......
...@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str): ...@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str):
return llm, prompts, stop_token_ids return llm, prompts, stop_token_ids
# Gemma 3
def run_gemma3(questions: list[str], modality: str):
assert modality == "image"
model_name = "google/gemma-3-4b-it"
llm = LLM(model=model_name,
max_model_len=2048,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompts = [("<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n") for question in questions]
stop_token_ids = None
return llm, prompts, stop_token_ids
# GLM-4v # GLM-4v
def run_glm4v(questions: list[str], modality: str): def run_glm4v(questions: list[str], modality: str):
assert modality == "image" assert modality == "image"
...@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str): ...@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str):
"type": "image" "type": "image"
}, { }, {
"type": "text", "type": "text",
"text": f"{question}" "text": question
}] }]
}] for question in questions] }] for question in questions]
prompts = tokenizer.apply_chat_template(messages, prompts = tokenizer.apply_chat_template(messages,
...@@ -664,6 +681,7 @@ model_example_map = { ...@@ -664,6 +681,7 @@ model_example_map = {
"deepseek_vl_v2": run_deepseek_vl2, "deepseek_vl_v2": run_deepseek_vl2,
"florence2": run_florence2, "florence2": run_florence2,
"fuyu": run_fuyu, "fuyu": run_fuyu,
"gemma3": run_gemma3,
"glm4v": run_glm4v, "glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl, "h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3, "idefics3": run_idefics3,
......
...@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): ...@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
) )
def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"
llm = LLM(model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)})
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-800m" model_name = "h2oai/h2ovl-mississippi-800m"
...@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData: ...@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
model_example_map = { model_example_map = {
"aria": load_aria, "aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2, "deepseek_vl_v2": load_deepseek_vl2,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl, "h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3, "idefics3": load_idefics3,
"internvl_chat": load_internvl, "internvl_chat": load_internvl,
......
...@@ -162,6 +162,7 @@ def _test_processing_correctness( ...@@ -162,6 +162,7 @@ def _test_processing_correctness(
"deepseek-ai/deepseek-vl2-tiny", "deepseek-ai/deepseek-vl2-tiny",
"microsoft/Florence-2-base", "microsoft/Florence-2-base",
"adept/fuyu-8b", "adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b", "THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
......
...@@ -124,6 +124,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -124,6 +124,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
min_transformers_version="4.50"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
...@@ -241,6 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -241,6 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
min_transformers_version="4.50"),
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
......
...@@ -350,10 +350,11 @@ class ModelConfig: ...@@ -350,10 +350,11 @@ class ModelConfig:
if self.enforce_eager is None: if self.enforce_eager is None:
self.enforce_eager = False self.enforce_eager = False
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None) sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and ( has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2", "cohere2"])) (self.hf_text_config.model_type in interleaved_attn_models))
if (not self.disable_sliding_window and has_interleaved_attention): if (not self.disable_sliding_window and has_interleaved_attention):
if (backend := if (backend :=
...@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype( ...@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype(
dtype = dtype.lower() dtype = dtype.lower()
if dtype == "auto": if dtype == "auto":
if config_dtype == torch.float32: if config_dtype == torch.float32:
if config.model_type == "gemma2": if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
logger.info( logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead " "For Gemma 2 and 3, we downcast float32 to bfloat16 "
"of float16 by default. Please specify `dtype` if you " "instead of float16 by default. Please specify `dtype` "
"want to use float16.") "if you want to use float16.")
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
else: else:
# Following the common practice, we use float16 for float32 # Following the common practice, we use float16 for float32
...@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len( ...@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None) rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None: # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
# scaling, so we skip applying the scaling factor again.
if rope_scaling is not None and "gemma3" not in hf_config.model_type:
# No need to consider "type" key because of patch_rope_scaling when # No need to consider "type" key because of patch_rope_scaling when
# loading HF config # loading HF config
rope_type = rope_scaling["rope_type"] rope_type = rope_scaling["rope_type"]
......
...@@ -433,6 +433,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -433,6 +433,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<image>" return "<image>"
if model_type == "aria": if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>" return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3":
return "<start_of_image>"
raise TypeError(f"Unknown {modality} model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio": elif modality == "audio":
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
Tuple, TypedDict, Union)
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, ProcessorMixin
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config()
return {"image": hf_config.mm_tokens_per_image}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin],
) -> int:
hf_config = self.ctx.get_hf_config()
return hf_config.mm_tokens_per_image
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50)
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
boi_token = tokenizer.boi_token
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=" ".join([boi_token] * num_images),
mm_data=mm_data,
)
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# TODO(woosuk): Support pan-and-scan.
img_kwargs = mm_kwargs.get("images_kwargs", {})
img_kwargs["do_pan_and_scan"] = False
mm_kwargs["images_kwargs"] = img_kwargs
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
boi_token = tokenizer.boi_token
image_token = tokenizer.image_token
mm_tokens_per_image = hf_config.mm_tokens_per_image
image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)
def get_replacement_gemma3(item_idx: int):
return PromptUpdateDetails(
full=hf_processor.full_image_sequence,
features=image_tokens_expanded,
)
return [
PromptReplacement(
modality="image",
target=boi_token,
replacement=get_replacement_gemma3,
)
]
class Gemma3MultiModalProjector(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(config.vision_config.hidden_size,
config.text_config.hidden_size))
self.mm_soft_emb_norm = GemmaRMSNorm(
config.vision_config.hidden_size,
eps=config.vision_config.layer_norm_eps)
self.patches_per_image = int(config.vision_config.image_size //
config.vision_config.patch_size)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size,
stride=self.kernel_size)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image,
self.patches_per_image)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight)
return projected_vision_outputs.type_as(vision_outputs)
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.sliding_window = config.text_config.interleaved_sliding_window
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
prefix=maybe_prefix(
prefix, "vision_tower"))
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Gemma3ForCausalLM"],
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
if d.shape != expected_dims:
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_dims}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
return Gemma3ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:
target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
return image_features
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = image_input["data"]
vision_outputs = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(vision_outputs)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
else:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
if vision_embeddings is not None:
kwargs = self.prepare_attn_masks(
input_ids,
positions,
mask_dtype=vision_embeddings.dtype,
**kwargs)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs)
return hidden_states
def prepare_attn_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
**kwargs,
):
kwargs["has_images"] = True
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_idices = (positions == 0).cpu().nonzero()
num_seqs = len(start_idices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_idices[i].item()
if i < num_seqs - 1:
end_idx = start_idices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens
global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create a global causal mask.
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0.
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Consider the bidirectional attention between image tokens.
img_mask = torch.zeros_like(global_attn_mask)
img_pos = (input_token_ids == self.config.image_token_index)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
...@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
...@@ -161,6 +162,7 @@ _MULTIMODAL_MODELS = { ...@@ -161,6 +162,7 @@ _MULTIMODAL_MODELS = {
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment