"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "69e72b1dd113927ed638f26e82738e9735385edc"
Unverified Commit 9b00990b authored by Xinyuan Tong, committed by GitHub

chore: remove unnecessary VLM import (#7541)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
parent 4d67025a
@@ -565,6 +565,7 @@ multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
+   "Gemma3nForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
......
@@ -823,6 +823,7 @@ register_conv_template(
        sep_style=SeparatorStyle.GEMMA3,
        stop_str=["<end_of_turn>"],
        image_token="<start_of_image>",
+       audio_token="<start_of_audio>",
    )
)
......
@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
    RAW_IMAGES = "raw_images"
    PRECOMPUTED_FEATURES = "precomputed_features"
    PIXEL_VALUES = "pixel_values"
+   AUDIO = "audio"


@dataclasses.dataclass
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
            has_image = False
            has_pixel_values = False
            has_precomputed_features = False
+           has_audio = False

            for mm_input in mm_inputs:
                if isinstance(mm_input, Image.Image):
                    has_image = True
+               elif isinstance(mm_input, np.ndarray):
+                   has_audio = True
                elif isinstance(mm_input, dict):
                    if mm_input.get("precomputed_features", None) is not None:
                        has_precomputed_features = True
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):
            # Validate format consistency
            format_count = sum(
-               [has_image, has_pixel_values, has_precomputed_features]
+               [has_image, has_pixel_values, has_precomputed_features, has_audio]
            )
            if format_count > 1:
                raise ValueError(
                    "Unsupported: mixture of multimodal input formats. "
                    f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                   f"precomputed_features={has_precomputed_features}"
+                   f"precomputed_features={has_precomputed_features}, audio={has_audio}"
                )

            if has_image:
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
                return MultimodalInputFormat.PRECOMPUTED_FEATURES
            elif has_pixel_values:
                return MultimodalInputFormat.PIXEL_VALUES
+           elif has_audio:
+               return MultimodalInputFormat.AUDIO
            else:
                raise ValueError("No valid multimodal input format found")
        except Exception as e:
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
            input_ids = tokenize_text(base_output.input_text)
            return combined_mm_item, input_ids

+       def process_audio(
+           base_output: BaseMultiModalProcessorOutput,
+       ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+           """Process inputs with audio."""
+           ret = self.process_mm_data(
+               input_text=base_output.input_text,
+               audio=base_output.audios,  # Note: "audio" is for gemma3n only
+           )
+           combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
+           for key, value in ret.items():
+               if key != "input_ids" and hasattr(combined_mm_item, key):
+                   setattr(combined_mm_item, key, value)
+           input_ids = ret["input_ids"].flatten()
+           return combined_mm_item, input_ids

        def finalize_mm_item(
            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
        ) -> MultimodalDataItem:
            """Apply common post-processing to the multimodal item."""
-           combined_mm_item.image_offsets = self.get_mm_items_offset(
-               input_ids=input_ids,
-               mm_token_id=self.IM_TOKEN_ID,
-           )
+           if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+               combined_mm_item.image_offsets = self.get_mm_items_offset(
+                   input_ids=input_ids,
+                   mm_token_id=self.IM_TOKEN_ID,
+               )
+           elif combined_mm_item.modality == Modality.AUDIO:
+               combined_mm_item.audio_offsets = self.get_mm_items_offset(
+                   input_ids=input_ids,
+                   mm_token_id=self.AUDIO_TOKEN_ID,
+               )
+           elif combined_mm_item.modality == Modality.VIDEO:
+               combined_mm_item.video_offsets = self.get_mm_items_offset(
+                   input_ids=input_ids,
+                   mm_token_id=self.VIDEO_TOKEN_ID,
+               )
+           else:
+               raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
            return combined_mm_item

-       # Main logic
-       mm_inputs = base_output.images
+       # Main logic - determine input type and handle text-only case
+       mm_inputs = base_output.images or base_output.audios
        if not mm_inputs:
+           # Return text-only case
            input_ids = tokenize_text(base_output.input_text)
            return None, input_ids

@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
            combined_mm_item, input_ids = process_precomputed_features(base_output)
        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
            combined_mm_item, input_ids = process_pixel_values(base_output)
+       elif input_format == MultimodalInputFormat.AUDIO:
+           combined_mm_item, input_ids = process_audio(base_output)
        else:
            raise ValueError(f"Unknown input format: {input_format}")
......
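The finalize_mm_item branches above record per-modality offsets via self.get_mm_items_offset. As a rough mental model only (a minimal stand-alone sketch, not the SGLang implementation), those offsets can be viewed as the inclusive (start, end) spans of contiguous runs of the modality's placeholder token id in input_ids:

from typing import List, Tuple

import torch


def find_token_runs(input_ids: torch.Tensor, mm_token_id: int) -> List[Tuple[int, int]]:
    """Return inclusive (start, end) pairs for contiguous runs of mm_token_id."""
    positions = (input_ids == mm_token_id).nonzero(as_tuple=True)[0].tolist()
    runs: List[Tuple[int, int]] = []
    for pos in positions:
        if runs and pos == runs[-1][1] + 1:
            runs[-1] = (runs[-1][0], pos)  # extend the current run
        else:
            runs.append((pos, pos))  # start a new run
    return runs


# Example: two audio items, each expanded to three soft tokens (token id 9).
ids = torch.tensor([5, 9, 9, 9, 7, 9, 9, 9, 2])
print(find_token_runs(ids, mm_token_id=9))  # [(1, 3), (5, 7)]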
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
from typing import Dict, List, Optional, Union
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
)
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
class Gemma3nSGLangProcessor(SGLangBaseProcessor):
"""Multimodal processor for Gemma3n supporting image and audio inputs."""
models = [Gemma3nForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image_soft_token>"
self.IMAGE_TOKEN_REGEX = re.compile(
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
self.AUDIO_TOKEN = "<audio_soft_token>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
)
self.IM_TOKEN_ID = hf_config.image_token_id
self.IM_START_TOKEN_ID = hf_config.boi_token_id
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
self.AUDIO_TOKEN_ID = hf_config.audio_token_id
self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
async def process_mm_data_async(
self,
image_data: Optional[List[Union[str, bytes, Dict]]] = None,
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
input_text: str = "",
request_obj=None,
max_req_input_len: int = 0,
*args,
**kwargs,
):
"""Process multimodal data including images and audio."""
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(audio_data, str):
audio_data = [audio_data]
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
audio_data=audio_data,
max_req_input_len=max_req_input_len,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
audio_token=self.AUDIO_TOKEN,
audio_token_regex=self.AUDIO_TOKEN_REGEX,
),
)
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"audio_start_id": self.AUDIO_START_TOKEN_ID,
"audio_end_id": self.AUDIO_END_TOKEN_ID,
}
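The IMAGE_TOKEN_REGEX and AUDIO_TOKEN_REGEX defined in __init__ accept either the bare start-of-media marker emitted by the conversation template or a marker that has already been expanded into soft tokens. A quick self-contained check using the same patterns (outside of SGLang):

import re

IMAGE_TOKEN_REGEX = re.compile(
    r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
AUDIO_TOKEN_REGEX = re.compile(
    r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
)

# The bare marker from the conversation template matches...
assert IMAGE_TOKEN_REGEX.fullmatch("<start_of_image>")
# ...and so does a marker already expanded into a run of soft tokens.
assert AUDIO_TOKEN_REGEX.fullmatch(
    "<start_of_audio>" + "<audio_soft_token>" * 4 + "<end_of_audio>"
)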
@@ -214,6 +214,10 @@ class MultimodalDataItem:
    audio_feature_lens: Optional[List[torch.Tensor]] = None
    audio_offsets: Optional[List[Tuple[int, int]]] = None

+   # gemma3n related
+   input_features: Optional[torch.Tensor] = None
+   input_features_mask: Optional[torch.Tensor] = None
+
    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
@@ -277,7 +281,10 @@ class MultimodalDataItem:
        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
-           self.hash = hash_feature(self.audio_features)
+           if self.audio_features is not None:
+               self.hash = hash_feature(self.audio_features)
+           elif self.input_features is not None:
+               self.hash = hash_feature(self.input_features)
        else:
            self.hash = hash_feature(self.pixel_values)
@@ -288,6 +295,7 @@ class MultimodalDataItem:
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
+           or not MultimodalDataItem.is_empty_list(self.input_features)
        )

    def is_image(self):
......
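With input_features added to MultimodalDataItem, the hashing path above prefers precomputed features, then the existing audio_features, then the new gemma3n input_features. A minimal stand-alone sketch of that precedence (hash_feature here is only a stand-in for SGLang's helper, assumed to accept any tensor):

import torch


def hash_feature(x: torch.Tensor) -> int:
    # Stand-in for SGLang's hash_feature helper; any stable tensor hash works for the sketch.
    return hash(x.cpu().numpy().tobytes())


def audio_item_hash(precomputed_features, audio_features, input_features):
    # Mirrors the precedence shown in the diff above for Modality.AUDIO items.
    if precomputed_features is not None:
        return hash_feature(precomputed_features)
    if audio_features is not None:
        return hash_feature(audio_features)
    return hash_feature(input_features)


print(audio_item_hash(None, None, torch.zeros(2, 4)))  # hashes the gemma3n input_features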
This diff is collapsed.
This diff is collapsed.
import logging
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import (
Gemma3nAudioConfig,
Gemma3nConfig,
Gemma3nTextConfig,
Gemma3nVisionConfig,
PreTrainedModel,
)
from transformers.models.auto.modeling_auto import AutoModel
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Gemma3nImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3nAudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
class Gemma3nMultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language model space."""
def __init__(
self,
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
text_config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.multimodal_hidden_size = multimodal_config.hidden_size
self.eps = multimodal_config.rms_norm_eps
self.vocab_offset = multimodal_config.vocab_offset
self.vocab_size = multimodal_config.vocab_size
self.text_hidden_size = text_config.hidden_size
self.embedding = VocabParallelEmbedding(
self.vocab_size,
self.multimodal_hidden_size,
quant_config=quant_config,
prefix=add_prefix("embedding", prefix),
)
self.hard_embedding_norm = Gemma3nRMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.soft_embedding_norm = Gemma3nRMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.embedding_projection = RowParallelLinear(
self.multimodal_hidden_size,
self.text_hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("embedding_projection", prefix),
)
self.embedding_post_projection_norm = Gemma3nRMSNorm(
self.text_hidden_size,
eps=self.eps,
with_scale=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Embeds token ids or soft tokens for multimodal content into language model space.
Args:
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
`[vocab_offset, vocab_offset + vocab_size)`.
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
Returns:
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
else:
# Handle out of vocab ids to prevent CUDA assertion failures
out_of_vocab_id = self.vocab_size - 1
adjusted_ids = input_ids - self.vocab_offset
adjusted_ids = torch.where(adjusted_ids < 0, out_of_vocab_id, adjusted_ids)
adjusted_ids = torch.where(
adjusted_ids >= self.vocab_size, out_of_vocab_id, adjusted_ids
)
hard_emb = self.embedding(adjusted_ids)
emb_norm = self.hard_embedding_norm(hard_emb)
emb_norm_proj, _ = self.embedding_projection(emb_norm)
return self.embedding_post_projection_norm(emb_norm_proj)
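# Example of the clamping above (numbers are illustrative, not taken from the
# real config): with vocab_offset=262144 and vocab_size=128, token id 262150
# maps to embedding row 6, while any id outside [262144, 262272) is redirected
# to the last row (127) rather than tripping a CUDA-side index assertion.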
class Gemma3nForConditionalGeneration(PreTrainedModel):
"""Gemma3n multimodal model for conditional generation."""

config_class = Gemma3nConfig
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
".out_proj.",
]
bitsandbytes_stacked_params_mapping = {
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
"out_proj": ("proj", 0),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3nConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
prefix = add_prefix("model", prefix)
# Vision components
# TODO: Use sglang's vision model
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma3nMultimodalEmbedder(
config.vision_config,
config.text_config,
quant_config=quant_config,
prefix=add_prefix("embed_vision", prefix),
)
# Audio components
self.embed_audio = Gemma3nMultimodalEmbedder(
config.audio_config,
config.text_config,
quant_config=quant_config,
prefix=add_prefix("embed_audio", prefix),
)
self.audio_tower = Gemma3nAudioEncoder(
config.audio_config,
quant_config=quant_config,
prefix=add_prefix("audio_tower", prefix),
)
self.vocab_size = config.text_config.vocab_size
self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
# Text model
self.language_model = Gemma3nTextModel(
config.text_config,
quant_config,
prefix=add_prefix("language_model", prefix),
)
# Create logits processor for the multimodal model
self.logits_processor = LogitsProcessor(config.text_config)
self.post_init()
def pad_input_ids(
self,
input_ids: List[int],
mm_inputs: Optional[MultimodalInputs] = None,
) -> List[int]:
"""Pad input IDs with image and audio tokens."""
if mm_inputs is None:
return input_ids
# Collect available media token pairs
media_token_pairs = []
for attr_name in ["im_start_id", "audio_start_id"]:
if hasattr(mm_inputs, attr_name):
start_id = getattr(mm_inputs, attr_name)
end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
media_token_pairs.append((start_id, end_id))
# Apply padding pattern if we have media tokens
if media_token_pairs:
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, mm_inputs)
return input_ids
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_attention_sliding_window_size(self):
return self.config.text_config.sliding_window - 1
def get_image_feature(self, items: List[MultimodalDataItem]):
"""
Projects the last hidden state from the vision model into language model space.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
# Process images one by one to handle flatten_batch=True constraint in vision_tower
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
vision_outputs_list = []
for pixel_values_batch in all_pixel_values:
# Normalize input shape to [batch_size, channels, height, width]
if pixel_values_batch.dim() == 5:
pixel_values_batch = pixel_values_batch.squeeze(0)
elif pixel_values_batch.dim() == 3:
pixel_values_batch = pixel_values_batch.unsqueeze(0)
elif pixel_values_batch.dim() != 4:
raise ValueError(
f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
)
# Process each image in the batch
batch_size = pixel_values_batch.shape[0]
for i in range(batch_size):
pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1
pixel_value = pixel_value.to(
device=self.vision_tower.device, dtype=self.language_model.dtype()
)
vision_outputs = self.vision_tower(
pixel_values=pixel_value, do_pooling=False, return_dict=True
).last_hidden_state
vision_outputs_list.append(vision_outputs)
# Concatenate all vision outputs
vision_outputs = torch.cat(vision_outputs_list, dim=0)
# Convert from (batch, channels, height, width) to (batch, height * width, channels)
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
).permute(0, 2, 1)
# Normalize and embed the soft tokens into language model space
vision_outputs *= self.config.vision_config.hidden_size**0.5
return self.embed_vision(inputs_embeds=vision_outputs)
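# Shape note for the block above (sizes illustrative): with do_pooling=False the
# vision tower yields (1, hidden_size, H, W) per image; after concatenation the
# reshape packs H * W into vision_soft_tokens_per_image and permute(0, 2, 1)
# gives (num_images, vision_soft_tokens_per_image, hidden_size), which is then
# scaled by sqrt(hidden_size) and projected into the text embedding space.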
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
"""
Projects the last hidden state from the audio encoder into language model space.
Args:
items: List of multimodal data items containing audio data.
Returns:
audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`.
"""
# Extract audio features and masks from items
all_input_features = flatten_nested_list(
[item.input_features for item in items]
)
all_input_features_mask = flatten_nested_list(
[~item.input_features_mask for item in items]
) # Note(Xinyuan): reverse the mask according to the HF implementation
# Process audio features one by one
audio_features_list = []
for input_features, input_features_mask in zip(
all_input_features, all_input_features_mask
):
# Ensure proper tensor format
if input_features.dim() == 2:
input_features = input_features.unsqueeze(0)
if input_features_mask.dim() == 1:
input_features_mask = input_features_mask.unsqueeze(0)
# Move to device and dtype
input_features = input_features.to(
device=next(self.audio_tower.parameters()).device,
dtype=self.language_model.dtype(),
)
input_features_mask = input_features_mask.to(device=input_features.device)
# Process through audio tower
audio_outputs, audio_mask = self.audio_tower(
input_features, input_features_mask
)
# Embed the audio outputs
audio_embeds = self.embed_audio(inputs_embeds=audio_outputs)
audio_features_list.append(audio_embeds)
# Concatenate all audio features
if audio_features_list:
audio_features = torch.cat(audio_features_list, dim=0)
# The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
audio_padding_toks = torch.tensor(
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
)
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
audio_features = torch.where(
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = (
self.config.audio_soft_tokens_per_image - audio_seq_len
)
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim
)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
return audio_features
else:
return torch.empty(
0,
0,
self.language_model.config.hidden_size,
device=next(self.parameters()).device,
dtype=self.language_model.dtype(),
)
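# Padding arithmetic for the branch above (numbers illustrative): if 188 audio
# soft tokens were reserved in the prompt but the encoder emitted only 176
# frames for the longest clip in the batch, extra_padding_tokens = 188 - 176 = 12,
# and those 12 positions are filled with the embedding of the last audio vocab
# entry so the feature length matches the reserved placeholder span.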
def get_per_layer_inputs(
self, input_ids: torch.LongTensor
) -> Optional[torch.Tensor]:
return self.language_model.get_per_layer_inputs(input_ids)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.language_model.project_per_layer_inputs(
inputs_embeds, per_layer_inputs
)
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs: object,
) -> LogitsProcessor:
"""Forward pass for multimodal Gemma3n."""
if (input_ids is None) ^ (input_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
positions += 1
if input_ids is not None:
# Prepare per-layer inputs from input_ids
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0, input_ids < self.vocab_size_per_layer_input
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_inputs = self.language_model.get_per_layer_inputs(
per_layer_inputs_tokens
)
# Use general_mm_embed_routine for handling multimodal data
# This will automatically handle text, image, and audio embeddings
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
positions=positions,
per_layer_inputs=per_layer_inputs,
)
# Process hidden states through logits processor
return self.logits_processor(
input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
)
def tie_weights(self):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".gate_proj", 0),
]
"""Load weights for the model."""
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
name = re.sub(r"^model\.", "", name)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "vision_model" in name:
# adapt to VisionAttention
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
EntryClass = Gemma3nForConditionalGeneration
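For readers tracing load_weights, the stacked_params_mapping above folds separate q/k/v (and gate/up) checkpoint tensors into the fused SGLang parameters. A stand-alone sketch of just the name remapping (the mapping tuples are copied from the code; the checkpoint names below are made up for illustration):

import re

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".up_proj", 1),
    (".gate_up_proj", ".gate_proj", 0),
]


def remap(name: str):
    # Strip the leading "model." prefix, then fold shard names into fused params.
    name = re.sub(r"^model\.", "", name)
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # falls through to default_weight_loader


print(remap("model.language_model.layers.0.self_attn.q_proj.weight"))
# ('language_model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.weight"))
# ('audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.weight', None)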