# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_siglip2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import warnings from dataclasses import dataclass from typing import Any, Callable, Optional, Union, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import _calculate_fan_in_and_fan_out from flash_attn import flash_attn_varlen_func from flash_attn.layers.rotary import apply_rotary_emb from transformers.activations import ACT2FN # from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import BaseModelOutputWithNoAttention from transformers.modeling_utils import PreTrainedModel from .configuration_siglip2_navit import Siglip2NavitConfig __all__ = ["Siglip2NavitModel"] # copied from qwen2.5-vl class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.outer(seq, self.inv_freq) return freqs class Siglip2VisionEmbeddings(nn.Module): def __init__(self, config: Siglip2NavitConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.patch_size = config.patch_size self.image_size = config.image_size self.num_patches = config.num_patches self.preserve_original_pe = config.preserve_original_pe self.hidden_stride = config.hidden_stride # siglip2 naflex if self.num_patches > 0: self.patch_embedding = nn.Linear( in_features=config.num_channels * self.patch_size * self.patch_size, out_features=self.embed_dim, ) if self.preserve_original_pe: self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) else: self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) if self.preserve_original_pe: self.num_patches = (self.image_size // self.patch_size) ** 2 self.position_embedding_size = self.image_size // self.patch_size self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) @staticmethod def resize_positional_embeddings( positional_embeddings: torch.Tensor, spatial_shapes: torch.LongTensor, max_length: int, ) -> torch.Tensor: """ Resize positional embeddings to image-specific size and pad to a fixed size. Args: positional_embeddings (`torch.Tensor`): Position embeddings of shape (height, width, embed_dim) spatial_shapes (`torch.LongTensor`): Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to max_length (`int`): Maximum length of the positional embeddings to pad resized positional embeddings to Returns: `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) """ batch_size = spatial_shapes.shape[0] embed_dim = positional_embeddings.shape[-1] source_dtype = positional_embeddings.dtype resulted_positional_embeddings = torch.empty( (batch_size, max_length, embed_dim), device=positional_embeddings.device, dtype=source_dtype, ) # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU if positional_embeddings.device.type == "cpu": positional_embeddings = positional_embeddings.to(torch.float32) for i in range(batch_size): # (1, dim, height, width) -> (1, dim, target_height, target_width) height, width = spatial_shapes[i] resized_embeddings = F.interpolate( positional_embeddings, size=(height, width), mode="bilinear", align_corners=False, antialias=True, ) # (1, dim, target_height, target_width) -> (target_height * target_width, dim) resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) # Cast to original dtype resized_embeddings = resized_embeddings.to(source_dtype) resulted_positional_embeddings[i, : height * width] = resized_embeddings resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] return resulted_positional_embeddings def forward(self, pixel_values: torch.FloatTensor, grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor`): Pixel values of shape (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) grid_thws: (`torch.LongTensor`): grid shape (num_patches, 3) """ # Apply patch embeddings to already patchified pixel values target_dtype = self.patch_embedding.weight.dtype if isinstance(self.patch_embedding, nn.Linear): patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) elif isinstance(self.patch_embedding, nn.Conv2d): pixel_values = pixel_values.view(-1, self.config.num_channels * self.config.temporal_patch_size, self.patch_size, self.patch_size) patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) patch_embeds = patch_embeds.reshape(-1, self.embed_dim) if self.preserve_original_pe: assert grid_thws is not None pos_embed_new = torch.zeros_like(patch_embeds) ori_h = ori_w = self.position_embedding_size positional_embeddings = self.position_embedding.weight.reshape( self.position_embedding_size, self.position_embedding_size, -1 ).unsqueeze(0).permute(0,3,1,2) # pos_embed = self.pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2) cnt = 0 for t, h, w in grid_thws: thw = t * h * w pe = F.interpolate(positional_embeddings, size=(h, w), mode='bicubic', align_corners=False) pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) pe = pe[0].repeat(t, 1) pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride, self.hidden_stride, -1) pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1) pos_embed_new[cnt:cnt + thw] = pe cnt += thw patch_embeds = patch_embeds + pos_embed_new return patch_embeds def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights # copied from qwen2.5-vl def apply_rotary_pos_emb_flashatt( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed class Siglip2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.use_rope = config.use_rope def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) queries = queries.view(seq_length, self.num_heads, self.head_dim) keys = keys.view(seq_length, self.num_heads, self.head_dim) values = values.view(seq_length, self.num_heads, self.head_dim) if self.use_rope: cos, sin = position_embeddings queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin) queries = queries.squeeze(0) keys = keys.squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( seq_length, -1 ) attn_output = self.out_proj(attn_output) return attn_output class Siglip2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class Siglip2EncoderLayer(nn.Module): def __init__(self, config: Siglip2NavitConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = Siglip2Attention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Siglip2MLP(config) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: torch.Tensor ) -> tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class Siglip2Encoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Siglip2EncoderLayer`]. Args: config: Siglip2NavitConfig """ def __init__(self, config: Siglip2NavitConfig): super().__init__() self.config = config self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self.rotary_pos_emb = VisionRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2) self.patch_size = config.patch_size self.hidden_stride = config.hidden_stride self.window_size = config.window_size self.spatial_merge_unit = config.hidden_stride * config.hidden_stride self.fullatt_block_indexes = None if config.fullatt_block_indexes is None else [int(i) for i in config.fullatt_block_indexes.split('|')] # copied from qwen2.5_vl def rot_pos_emb(self, grid_thw): pos_ids = [] for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = hpos_ids.reshape( h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride, self.hidden_stride, ) hpos_ids = hpos_ids.permute(0, 2, 1, 3) hpos_ids = hpos_ids.flatten() wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) wpos_ids = wpos_ids.reshape( h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride, self.hidden_stride, ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb def get_window_index(self, grid_thw): window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size # patch (after merge) number in each window for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.hidden_stride, # number of patch after merge grid_w // self.hidden_stride, ) index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) index_padded = index_padded.reshape( grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size, ) index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size, ) seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) return window_index, cu_window_seqlens def forward( self, inputs_embeds, grid_thws: torch.Tensor, output_hidden_states: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ rotary_pos_emb = self.rot_pos_emb(grid_thws) window_index, cu_window_seqlens = self.get_window_index(grid_thws) cu_window_seqlens = torch.tensor( cu_window_seqlens, device=inputs_embeds.device, dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) seq_len, _ = inputs_embeds.size() inputs_embeds = inputs_embeds.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) inputs_embeds = inputs_embeds[window_index, :, :] inputs_embeds = inputs_embeds.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum( dim=0, # Select dtype based on the following factors: # - FA2 requires that cu_seqlens_q must have dtype int32 # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw # See https://github.com/huggingface/transformers/pull/34852 for more information dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) reverse_indices = torch.argsort(window_index) encoder_states = () if output_hidden_states else None hidden_states = inputs_embeds for index, block in enumerate(self.layers): if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes: cu_seqlens_tmp = cu_seqlens else: cu_seqlens_tmp = cu_window_seqlens if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states, cu_seqlens_tmp, position_embeddings) else: hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings) if output_hidden_states: hidden_states_ = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) encoder_states += (hidden_states_[reverse_indices, :].reshape(seq_len, -1),) # tokens = self.post_trunk_norm(tokens) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) return hidden_states, encoder_states class Siglip2VisionTransformer(nn.Module): def __init__(self, config: Siglip2NavitConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = Siglip2VisionEmbeddings(config) self.encoder = Siglip2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, pixel_values: torch.FloatTensor, grid_thws: torch.LongTensor, output_hidden_states: Optional[bool] = True, return_dict: Optional[bool] = True, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], BaseModelOutputWithNoAttention, ]: r""" spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): Tensor containing the spatial dimensions (height, width) of the input images. """ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # output_hidden_states = ( # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # ) hidden_states = self.embeddings(pixel_values, grid_thws) last_hidden_state, hidden_states = self.encoder(hidden_states, grid_thws, output_hidden_states) last_hidden_state = self.post_layernorm(last_hidden_state) if not return_dict: output = (last_hidden_state,) output += (hidden_states,) if output_hidden_states else () return output return BaseModelOutputWithNoAttention( last_hidden_state=last_hidden_state, hidden_states=hidden_states ) def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsequently scaled and shifted by the mean and std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") class Siglip2PreTrainedModel(PreTrainedModel): config_class = Siglip2NavitConfig base_model_prefix = "siglip2_navit" supports_gradient_checkpointing = True _no_split_modules = [ "Siglip2VisionEmbeddings", "Siglip2EncoderLayer", ] _supports_flash_attn_2 = True _supports_sdpa = False _supports_flex_attn = False _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Siglip2VisionEmbeddings): width = ( self.config.hidden_size if isinstance(self.config, Siglip2NavitConfig) else self.config.hidden_size ) nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, Siglip2Attention): nn.init.xavier_uniform_(module.q_proj.weight) nn.init.xavier_uniform_(module.k_proj.weight) nn.init.xavier_uniform_(module.v_proj.weight) nn.init.xavier_uniform_(module.out_proj.weight) nn.init.zeros_(module.q_proj.bias) nn.init.zeros_(module.k_proj.bias) nn.init.zeros_(module.v_proj.bias) nn.init.zeros_(module.out_proj.bias) elif isinstance(module, Siglip2MLP): nn.init.xavier_uniform_(module.fc1.weight) nn.init.xavier_uniform_(module.fc2.weight) nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Siglip2MultiheadAttentionPoolingHead): nn.init.xavier_uniform_(module.probe.data) nn.init.xavier_uniform_(module.attention.in_proj_weight.data) nn.init.zeros_(module.attention.in_proj_bias.data) elif isinstance(module, Siglip2Model): logit_scale_init = torch.log(torch.tensor(1.0)) module.logit_scale.data.fill_(logit_scale_init) module.logit_bias.data.zero_() elif isinstance(module, Siglip2ForImageClassification): nn.init.normal_( module.classifier.weight, std=self.config.hidden_size**-0.5 * self.config.initializer_factor, ) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class Siglip2NavitModel(Siglip2PreTrainedModel): config_class = Siglip2NavitConfig main_input_name = "pixel_values" def __init__(self, config: Siglip2NavitConfig): super().__init__(config) self.vision_model = Siglip2VisionTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding def forward( self, pixel_values: torch.FloatTensor, grid_thws: torch.LongTensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], BaseModelOutputWithNoAttention, ]: r""" pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): Mask to avoid performing attention on padding pixel indices. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): Tensor containing the spatial dimensions (height, width) of the input images. Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Siglip2VisionModel >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" if output_hidden_states is None: output_hidden_states = self.config.output_hidden_states if return_dict is None: return_dict = self.config.use_return_dict return self.vision_model( pixel_values=pixel_values, grid_thws=grid_thws, output_hidden_states=output_hidden_states, return_dict=return_dict, ) def _get_block(self, layer_index): return self.vision_model.encoder.layers[layer_index] def _get_attn_weight(self, layer_index): return torch.cat([self._get_block(layer_index).self_attn.q_proj.weight, self._get_block(layer_index).self_attn.k_proj.weight, self._get_block(layer_index).self_attn.v_proj.weight]) def _get_pose_embed(self): return self.vision_model.embeddings.position_embedding.weight