Commit 876a36a4 authored by raojy's avatar raojy
Browse files

first

parent eda2afb8
# Copyright (c) 2024 The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
import torch
from flash_attn import flash_attn_varlen_func
from torch import nn
from transformers.activations import ACT2FN
from modeling.siglip.configuration_siglip import (
SiglipVisionConfig as _SiglipVisionConfig,
)
from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
class SiglipVisionConfig(_SiglipVisionConfig):
    r"""
    Configuration for the packed-sequence Siglip vision encoder defined in this file.

    Every architecture argument is forwarded unchanged to the parent Siglip vision configuration; the only
    addition made by this subclass is the `rope` flag, which is stored on the instance and not forwarded.
    With the defaults the architecture arguments match those of
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read
    the documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward ("intermediate") layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            `"gelu"`, `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope (`bool`, *optional*, defaults to `True`):
            Selects the position-encoding scheme of the vision tower in this file. When truthy, a 2D rotary
            table is built and applied to queries and keys inside attention; when falsy, a learned
            `nn.Embedding` position table is added to the patch embeddings instead.

    Example:

    ```python
    >>> # Initializing a configuration with google/siglip-base-patch16-224 style defaults
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from that configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        rope=True,
        **kwargs,
    ):
        # Everything the parent configuration already understands is gathered
        # in one mapping; `rope` is deliberately kept out of it.
        parent_settings = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            attention_dropout=attention_dropout,
        )
        super().__init__(**parent_settings, **kwargs)

        # Subclass-only switch, read by the embedding, attention and
        # transformer modules in this file.
        self.rope = rope
class RotaryEmbedding2D(torch.nn.Module):
    """Precomputed rotary cos/sin lookup tables for a ``max_h`` x ``max_w`` grid.

    Four buffers are registered, in this order: ``cos_h``, ``sin_h``, ``cos_w``
    and ``sin_w``. Each has one row per grid cell, flattened row-major (cell
    ``(h, w)`` lives at row ``h * max_w + w``), and ``2 * len(range(0, dim, 2))``
    columns because the angle vector is concatenated with itself. The ``*_h``
    tables rotate by the row index, the ``*_w`` tables by the column index.

    Args:
        dim: Rotary width; frequencies are ``1 / base ** (i / dim)`` for
            ``i`` in ``range(0, dim, 2)``.
        max_h: Number of grid rows covered by the tables.
        max_w: Number of grid columns covered by the tables.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_h, max_w, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        inv_freq = 1.0 / (base**exponents)

        # Coordinates are converted to the frequency dtype before being tiled
        # into full (max_h, max_w) grids.
        row_ids = torch.arange(0, max_h).to(inv_freq.dtype)
        col_ids = torch.arange(0, max_w).to(inv_freq.dtype)
        row_grid = row_ids[:, None].repeat(1, max_w)
        col_grid = col_ids[None, :].repeat(max_h, 1)

        # Registration order (cos_h, sin_h, cos_w, sin_w) is preserved.
        for axis, grid in (("h", row_grid), ("w", col_grid)):
            cos_table, sin_table = self._forward_one_side(grid, inv_freq)
            self.register_buffer(f"cos_{axis}", cos_table)
            self.register_buffer(f"sin_{axis}", sin_table)

    def _forward_one_side(self, grid, inv_freq):
        """Return ``(cos, sin)`` tables for one coordinate grid.

        ``grid`` is a 2-D coordinate tensor and ``inv_freq`` a 1-D frequency
        vector; the result has the two grid axes flattened into one.
        """
        angles = grid.unsqueeze(-1) * inv_freq.view(1, 1, -1)
        table = torch.cat((angles, angles), dim=-1).flatten(0, 1)
        return table.cos(), table.sin()
def rotate_half(x):
    """Swap the two halves of the last axis, negating the half moved to the front.

    The split point is ``x.shape[-1] // 2``; for an odd width the trailing
    (larger) part is the one that gets negated and placed first.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate ``q`` and ``k`` with the given cos/sin tables.

    ``cos`` and ``sin`` receive a singleton axis at position 1 so that they
    broadcast over the head axis of ``q`` and ``k``.

    Returns:
        The pair ``(q_rotated, k_rotated)``.
    """
    # unsqueeze due to the head dimension
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
class SiglipVisionEmbeddings(nn.Module):
    """Patch (and optional learned position) embeddings for packed vision inputs.

    The patch projection starts out as an ``nn.Conv2d`` whose kernel and stride
    both equal ``config.patch_size``. ``convert_conv2d_to_linear`` can replace
    it with an equivalent ``nn.Linear`` that consumes already-flattened patches.
    A learned ``nn.Embedding`` position table is created only when
    ``config.rope`` is falsy.
    """

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        if not config.rope:
            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def convert_conv2d_to_linear(self, config, meta=False):
        """Replace the Conv2d patch projection with an equivalent ``nn.Linear``.

        The convolution weight ``(out, C, kh, kw)`` is permuted to
        ``(out, kh, kw, C)`` and flattened, so the linear layer expects each
        patch flattened in ``(kh, kw, C)`` order. The bias tensor is reused.

        Calling this method again after the conversion is a no-op. Previously a
        second call attempted a 4-axis ``permute`` on the 2-D linear weight and
        raised.

        Args:
            config: Object providing ``num_channels``.
            meta: When true, the new layer is allocated on the ``"meta"``
                device before its weight and bias data are swapped in.
        """
        if isinstance(self.patch_embedding, nn.Linear):
            # Already converted; nothing left to do.
            return

        in_features = config.num_channels * self.patch_size**2
        linear_patch_embedding = nn.Linear(
            in_features,
            self.embed_dim,
            bias=True,
            device="meta" if meta else None,
        )

        conv = self.patch_embedding
        W = conv.weight.permute(0, 2, 3, 1).reshape(self.embed_dim, in_features)
        linear_patch_embedding.weight.data = W
        linear_patch_embedding.bias.data = conv.bias.data

        # Delete before re-assigning so the module is re-registered exactly as
        # the original implementation did.
        del self.patch_embedding
        self.patch_embedding = linear_patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.FloatTensor,
        packed_flattened_position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        """Embed patches and, without RoPE, add learned position embeddings.

        Args:
            packed_pixel_values: Input handed to ``self.patch_embedding``. After
                ``convert_conv2d_to_linear`` this is expected to be one
                flattened patch per row (assumption based on the linear layer's
                input width; confirm against the caller).
            packed_flattened_position_ids: Indices into the learned position
                table; only read when ``config.rope`` is falsy.

        Returns:
            The patch embeddings, plus position embeddings when RoPE is off.
        """
        patch_embeds = self.patch_embedding(packed_pixel_values)
        if not self.config.rope:
            embeddings = patch_embeds + self.position_embedding(
                packed_flattened_position_ids
            )
        else:
            embeddings = patch_embeds
        return embeddings
class SiglipFlashAttention2(SiglipAttention):
    """Packed (variable-length) self-attention computed with ``flash_attn_varlen_func``.

    The projections (``q_proj``, ``k_proj``, ``v_proj``, ``out_proj``) and the
    ``num_heads`` / ``head_dim`` / ``config`` attributes are not set here; they
    are presumably provided by ``SiglipAttention`` (not shown in this file).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend over a packed batch of tokens.

        Args:
            hidden_states: 2-D tensor; its first axis is the total token count.
            cu_seqlens: Passed as both ``cu_seqlens_q`` and ``cu_seqlens_k``.
            max_seqlen: Passed as both ``max_seqlen_q`` and ``max_seqlen_k``.
            cos_h, sin_h: Rotary tables applied to the first half of every head
                when ``config.rope`` is truthy.
            cos_w, sin_w: Rotary tables applied to the second half of every head
                when ``config.rope`` is truthy.
            **kwargs: Accepted and ignored.

        Returns:
            ``out_proj`` applied to the attention output, one row per token.
        """
        total_q_len, _ = hidden_states.size()
        per_head_shape = (total_q_len, self.num_heads, self.head_dim)

        query_states = self.q_proj(hidden_states).view(per_head_shape)
        key_states = self.k_proj(hidden_states).view(per_head_shape)
        value_states = self.v_proj(hidden_states).view(per_head_shape)

        if self.config.rope:
            # The head dimension is split in two: the first half is rotated by
            # the "h" tables, the second half by the "w" tables.
            half = self.head_dim // 2
            q_h, q_w = query_states[..., :half], query_states[..., half:]
            k_h, k_w = key_states[..., :half], key_states[..., half:]
            q_h, k_h = apply_rotary_pos_emb(q_h, k_h, cos_h, sin_h)
            q_w, k_w = apply_rotary_pos_emb(q_w, k_w, cos_w, sin_w)
            query_states = torch.cat([q_h, q_w], dim=-1)
            key_states = torch.cat([k_h, k_w], dim=-1)

        # The kernel is always fed bfloat16 inputs, whatever the incoming dtype.
        q_bf16, k_bf16, v_bf16 = (
            states.to(torch.bfloat16)
            for states in (query_states, key_states, value_states)
        )
        attn_output = flash_attn_varlen_func(
            q_bf16,
            k_bf16,
            v_bf16,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
        )

        merged_heads = attn_output.reshape(total_q_len, -1)
        return self.out_proj(merged_heads)
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``.

    The activation is looked up in ``ACT2FN`` under ``config.hidden_act``;
    ``fc1`` maps ``hidden_size`` to ``intermediate_size`` and ``fc2`` maps back.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]

        width, inner_width = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(width, inner_width)
        self.fc2 = nn.Linear(inner_width, width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``hidden_states``."""
        expanded = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(expanded))
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each wrapped in a residual."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        # Sub-modules are created in the same order as before so that parameter
        # registration and random initialisation are unchanged.
        self.self_attn = SiglipFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run one encoder layer over packed tokens.

        ``cu_seqlens``, ``max_seqlen`` and the four rotary tables are handed to
        the attention module unchanged.
        """
        rotary_tables = dict(cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)

        # Attention sub-block.
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            **rotary_tables,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``SiglipEncoderLayer`` modules."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(SiglipEncoderLayer(config))

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        """Feed ``inputs_embeds`` through every layer in order.

        The sequence metadata and rotary tables are the same for all layers.

        Returns:
            The output of the last layer (``inputs_embeds`` itself when there
            are no layers).
        """
        rotary_tables = {
            "cos_h": cos_h,
            "sin_h": sin_h,
            "cos_w": cos_w,
            "sin_w": sin_w,
        }
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(
                hidden_states, cu_seqlens, max_seqlen, **rotary_tables
            )
        return hidden_states
class SiglipVisionTransformer(nn.Module):
    """Embeddings -> encoder stack -> final LayerNorm, for packed vision tokens.

    When ``config.rope`` is truthy a ``RotaryEmbedding2D`` table is built for a
    square grid with ``image_size // patch_size`` cells per side and a rotary
    width of half the per-head dimension.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        if config.rope:
            grid_side = config.image_size // config.patch_size
            head_width = config.hidden_size // config.num_attention_heads
            self.rope = RotaryEmbedding2D(head_width // 2, grid_side, grid_side)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Encode packed patches.

        Args:
            packed_pixel_values: Input to the patch embedding.
            packed_flattened_position_ids: Per-token indices, used both by the
                embedding module and to gather rows from the rotary tables.
            cu_seqlens: Forwarded to the encoder.
            max_seqlen: Forwarded to the encoder.

        Returns:
            The encoder output after the final LayerNorm.
        """
        hidden_states = self.embeddings(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
        )

        rotary_tables = {}
        if self.config.rope:
            # Gather one row of every table per packed token.
            for table_name in ("cos_h", "sin_h", "cos_w", "sin_w"):
                table = getattr(self.rope, table_name)
                rotary_tables[table_name] = table[packed_flattened_position_ids]

        encoded = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            **rotary_tables,
        )
        return self.post_layernorm(encoded)
class SiglipVisionModel(SiglipPreTrainedModel):
    """Pretrained-model wrapper around ``SiglipVisionTransformer``."""

    config_class = SiglipVisionConfig
    main_input_name = "packed_pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped vision transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Delegate to ``self.vision_model`` with the same four arguments."""
        vision_inputs = {
            "packed_pixel_values": packed_pixel_values,
            "packed_flattened_position_ids": packed_flattened_position_ids,
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlen,
        }
        return self.vision_model(**vision_inputs)
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tokenizers_available,
is_torch_available,
)
# Submodule name -> public names it provides. This mapping is handed to
# `_LazyModule` below; optional entries are added only when their backing
# dependency is reported as available.
_import_structure = {
    "configuration_qwen2": ["Qwen2Config"],
    "tokenization_qwen2": ["Qwen2Tokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]

if is_torch_available():
    _import_structure["modeling_qwen2"] = [
        "Qwen2ForCausalLM",
        "Qwen2Model",
        "Qwen2PreTrainedModel",
    ]

if TYPE_CHECKING:
    # Real imports, visible to static type checkers only.
    from .configuration_qwen2 import Qwen2Config
    from .tokenization_qwen2 import Qwen2Tokenizer

    if is_tokenizers_available():
        from .tokenization_qwen2_fast import Qwen2TokenizerFast

    if is_torch_available():
        from .modeling_qwen2 import (
            Qwen2ForCausalLM,
            Qwen2Model,
            Qwen2PreTrainedModel,
        )
else:
    import sys

    # At runtime this module object is swapped for a `_LazyModule` built from
    # the mapping above.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
    )
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Qwen2 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen2Config(PretrainedConfig):
    r"""
    Configuration class holding the hyper-parameters of a [`Qwen2Model`]. Instantiating it with the defaults
    yields a configuration similar to that of Qwen2-7B-beta
    [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read
    the documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model, i.e. the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Qwen2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key/value heads used for Grouped Query Attention. With
            `num_key_value_heads=num_attention_heads` the model uses Multi Head Attention (MHA), with
            `num_key_value_heads=1` it uses Multi Query Attention (MQA), otherwise GQA. When converting a
            multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by
            mean-pooling all the original heads within that group; see
            [this paper](https://arxiv.org/pdf/2305.13245.pdf). Passing `None` falls back to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
            Only relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new
            rope type and expect the model to work on a longer `max_position_embeddings`, update that value
            accordingly. A legacy `type` key is copied into `rope_type` in place.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. One of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE
                    embeddings. In most scaling types, a `factor` of x will enable the model to handle
                    sequences of length x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used
                    during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation,
                    using the `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the
                    linear ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the
                    linear ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as
                    the hidden size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as
                    the hidden size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. Stored as `None` whenever `use_sliding_window` is
            falsy.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the
            top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Stored on the configuration and copied by the attention layers of this model onto themselves.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Stored on the configuration as `_attn_implementation`.

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        **kwargs,
    ):
        # Model dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Sliding-window attention: the window size is only kept when enabled.
        self.use_sliding_window = use_sliding_window
        if use_sliding_window:
            self.sliding_window = sliding_window
        else:
            self.sliding_window = None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.is_causal = is_causal
        self._attn_implementation = _attn_implementation

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        scaling = self.rope_scaling
        if scaling is not None and "type" in scaling:
            scaling["rope_type"] = scaling["type"]
        rope_config_validation(self)

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""PyTorch Qwen2 model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_qwen2 import Qwen2Config
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
_CONFIG_FOR_DOC = "Qwen2Config"
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise the last axis by its RMS, computing in float32.

        The normalised tensor is cast back to the input dtype before being
        multiplied by ``self.weight``.
        """
        source_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(source_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Computes rotary cos/sin tables for a batch of position ids.

    The inverse frequencies and the attention scaling factor come from the
    function registered in ``ROPE_INIT_FUNCTIONS`` under the selected rope type.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[Qwen2Config] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            cached_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            scaling = config.rope_scaling
            if scaling is None:
                self.rope_type = "default"
            else:
                self.rope_type = scaling.get("rope_type", scaling.get("type"))
            cached_len = config.max_position_embeddings

        self.max_seq_len_cached = cached_len
        self.original_max_seq_len = cached_len

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1

        grew = seq_len > self.max_seq_len_cached
        if grew:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            # TODO joao: may break with compilation
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        back_in_original_range = (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        )
        if back_in_original_range:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to the dtype of ``x``.

        ``x`` is only read for its device and dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        if not (isinstance(device_type, str) and device_type != "mps"):
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = inv_freq_expanded.float() @ position_ids_expanded.float()
            freqs = angles.transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos, sin = emb.cos(), emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        target_dtype = x.dtype
        return cos.to(dtype=target_dtype), sin.to(dtype=target_dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    midpoint = x.shape[-1] // 2
    leading = x[..., :midpoint]
    trailing = x[..., midpoint:]
    # The trailing part moves to the front with its sign flipped.
    return torch.cat((trailing.neg(), leading), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`. For
            example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim]: use `unsqueeze_dim=1` when `q`
            and `k` have the shape [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when they have
            the shape [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free; the activation is looked up in
    ``ACT2FN`` under ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        width, inner_width = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner_width, bias=False)
        self.up_proj = nn.Linear(width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Apply the gated feed-forward block to ``hidden_state``."""
        gate = self.act_fn(self.gate_proj(hidden_state))
        gated = gate * self.up_proj(hidden_state)
        return self.down_proj(gated)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): the hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim). With `n_rep == 1`
    the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Add a repeat axis right after the kv-head axis, broadcast it, then fold
    # it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return tiled.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Qwen2Attention(nn.Module):
    """
    Eager (matmul + softmax) multi-headed self-attention with rotary position embeddings.

    Queries use `config.num_attention_heads` heads while keys/values use `config.num_key_value_heads` heads; the
    key/value heads are expanded with `repeat_kv` so that every query head has a matching key/value head.
    """

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        """
        Args:
            config (`Qwen2Config`):
                Supplies `hidden_size`, `num_attention_heads`, `num_key_value_heads`, `max_position_embeddings`,
                `rope_theta`, `is_causal` and `attention_dropout`.
            layer_idx (`int`, *optional*):
                Index of this layer; it is forwarded to `past_key_value.update` when a cache is used.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = config.is_causal
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=True
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states (`torch.Tensor`): 3D input unpacked as `(batch, seq_len, hidden)`.
            attention_mask (`torch.Tensor`, *optional*): 4D tensor added to the attention scores; it is sliced
                along its last dimension to the key length.
            position_ids (`torch.LongTensor`, *optional*): Only used by the legacy fallback that derives cos/sin
                from a `rotary_emb` attribute when `position_embeddings` is not given.
            past_key_value (`Cache`, *optional*): Cache updated in place with this layer's key/value states.
            output_attentions (`bool`): Whether to return the attention probabilities.
            use_cache (`bool`): Accepted for interface compatibility; not read by this method.
            cache_position (`torch.LongTensor`, *optional*): Forwarded to `past_key_value.update`.
            position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): Precomputed `(cos, sin)`.

        Returns:
            `(attn_output, attn_weights, past_key_value)`; `attn_weights` is `None` unless `output_attentions`.

        Raises:
            ValueError: If `position_embeddings` is missing and the module has no `rotary_emb`, or if the
                attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        if position_embeddings is None:
            # `__init__` does not create a rotary embedding module, so the legacy fallback only works when
            # one has been attached externally. Fail with an actionable message instead of an AttributeError.
            rotary_emb = getattr(self, "rotary_emb", None)
            if rotary_emb is None:
                raise ValueError(
                    "`position_embeddings` (cos, sin) must be provided: this attention layer has no `rotary_emb` "
                    "module to derive them from `position_ids`."
                )
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class Qwen2FlashAttention2(Qwen2Attention):
    """
    Qwen2 flash attention module. It inherits from `Qwen2Attention` and keeps the same weights; only the forward
    pass differs, delegating the attention computation to `_flash_attention_forward`.

    A sliding window is passed to the kernel only when `config.use_sliding_window` is set, `config.sliding_window`
    is not `None`, and `layer_idx >= config.max_window_layers`. This path never materializes attention
    probabilities, so the returned attention weights are always `None`.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        """Build the parent attention module and record which causal-mask alignment flash-attn uses."""
        super().__init__(*args, **kwargs)
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
    ):
        """
        Args:
            hidden_states (`torch.Tensor`): 3D input unpacked as `(batch, seq_len, hidden)`.
            attention_mask (`torch.Tensor`, *optional*): Forwarded unchanged to `_flash_attention_forward`.
            position_ids (`torch.LongTensor`, *optional*): Forwarded to `_flash_attention_forward`; also used by
                the legacy fallback that derives cos/sin from a `rotary_emb` attribute.
            past_key_value (`Cache`, *optional*): Cache updated in place with this layer's key/value states.
            output_attentions (`bool`): Accepted for interface compatibility; attention weights are never
                computed here.
            use_cache (`bool`): Accepted for interface compatibility; not read by this method.
            cache_position (`torch.LongTensor`, *optional*): Forwarded to `past_key_value.update`.
            position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): Precomputed `(cos, sin)`.

        Returns:
            `(attn_output, None, past_key_value)`.

        Raises:
            ValueError: If `position_embeddings` is missing and the module has no `rotary_emb`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        if position_embeddings is None:
            # The parent `__init__` does not create a rotary embedding module, so the legacy fallback only works
            # when one has been attached externally. Fail with an actionable message instead of an AttributeError.
            rotary_emb = getattr(self, "rotary_emb", None)
            if rotary_emb is None:
                raise ValueError(
                    "`position_embeddings` (cos, sin) must be provided: this attention layer has no `rotary_emb` "
                    "module to derive them from `position_ids`."
                )
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the layout passed to Flash Attention: (batch, seq, heads, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # No attention probabilities are computed on this path. Always bind the name so that
        # `output_attentions=True` returns `None` instead of raising UnboundLocalError.
        attn_weights = None
        return attn_output, attn_weights, past_key_value
# Maps `config._attn_implementation` to the attention class that `Qwen2DecoderLayer` instantiates.
QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
}
class Qwen2DecoderLayer(nn.Module):
    """One decoder block: RMS-normed self-attention and RMS-normed MLP, each added back to its input."""

    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        attn_impl = config._attn_implementation
        self.hidden_size = config.hidden_size

        if config.sliding_window and attn_impl != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{attn_impl}`; "
                "unexpected results may be encountered."
            )

        # Submodules are created in the same order as before so parameter ordering is unchanged.
        attention_cls = QWEN2_ATTENTION_CLASSES[attn_impl]
        self.self_attn = attention_cls(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): Layer input.
            attention_mask (`torch.FloatTensor`, *optional*): Forwarded unchanged to the attention module.
            position_ids (`torch.LongTensor`, *optional*): Forwarded unchanged to the attention module.
            past_key_value (*optional*): Cache object forwarded to the attention module.
            output_attentions (`bool`, *optional*): When truthy, the attention weights returned by the attention
                module are appended to the output tuple.
            use_cache (`bool`, *optional*): When truthy, the cache returned by the attention module is appended
                to the output tuple.
            cache_position (`torch.LongTensor`, *optional*): Forwarded unchanged to the attention module.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): `(cos, sin)` pair
                forwarded unchanged to the attention module.
            kwargs (`dict`, *optional*): Accepted and ignored.

        Returns:
            Tuple `(hidden_states, [attention weights], [cache])`, with the optional entries in that order.
        """
        # Self-attention sub-block with residual connection.
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)
        return (hidden_states, *extras)
# Docstring preamble handed to `add_start_docstrings` on the Qwen2 model classes in this file.
QWEN2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Qwen2Config`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2PreTrainedModel(PreTrainedModel):
    # Class-level settings read by the `PreTrainedModel` base class (semantics are defined upstream).
    config_class = Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        # Linear and Embedding weights are drawn from N(0, initializer_range); Linear biases and the
        # embedding row at `padding_idx` are zeroed. Any other module type is left untouched.
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()
# Argument documentation handed to `add_start_docstrings_to_model_forward` on the `forward` methods in this file.
QWEN2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2Model(Qwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        decoder_layers = [
            Qwen2DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Flags left as `None` fall back to the values stored on the config.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Exactly one of the two input forms may be given: both present or both absent is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is not None:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )
            else:
                past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            seen = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                seen,
                seen + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # The mask is only forwarded to the layers when it contains at least one zero entry.
        causal_mask = None
        if attention_mask is not None and 0.0 in attention_mask:
            causal_mask = attention_mask

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        collected_hidden = []
        collected_attns = []
        next_decoder_cache = None
        cache_slot = 2 if output_attentions else 1
        use_checkpointing = self.gradient_checkpointing and self.training

        for decoder_layer in self.layers:
            if output_hidden_states:
                collected_hidden.append(hidden_states)

            if use_checkpointing:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[cache_slot]
            if output_attentions:
                collected_attns.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            collected_hidden.append(hidden_states)

        all_hidden_states = tuple(collected_hidden) if output_hidden_states else None
        all_self_attns = tuple(collected_attns) if output_attentions else None

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )

        candidates = (hidden_states, next_cache, all_hidden_states, all_self_attns)
        return tuple(item for item in candidates if item is not None)
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
    # `Qwen2Model` backbone followed by a bias-free linear head that projects hidden states to vocabulary logits.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
        >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Flags left as `None` fall back to the values stored on the config.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # With the default `num_logits_to_keep=0` the slice `-0:` keeps every position.
        kept_states = outputs[0][:, -num_logits_to_keep:, :]
        logits = self.lm_head(kept_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tail = (logits,) + outputs[1:]
        if loss is None:
            return tail
        return (loss,) + tail
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Tokenization classes for Qwen2."""
import json
import os
import unicodedata
from functools import lru_cache
from typing import Optional, Tuple
import regex as re
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
# Module-level logger from the transformers logging utilities.
logger = logging.get_logger(__name__)
# File names of the JSON vocabulary and the BPE merges list that the tokenizer loads.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}
# NOTE(review): presumably the maximum input length per tokenizer checkpoint; not read in the visible code.
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
# Pre-tokenization pattern compiled in `Qwen2Tokenizer.__init__` and applied with `re.findall` in `_tokenize`.
# It uses `\p{L}` / `\p{N}` Unicode property classes, so it needs the third-party `regex` module (imported as `re`).
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
@lru_cache()
# Adapted from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by the byte-level BPE.

    Bytes in the three printable ranges below map to the character with the same code point. Every other byte
    is assigned, in increasing byte order, the next unused code point starting at 256, so no byte maps to a
    whitespace or control character. The result has one entry for each of the 256 byte values and is cached.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}

    shift = 0
    for byte in range(2**8):
        if byte in table:
            continue
        # Remaining bytes are remapped above the single-byte range.
        table[byte] = chr(2**8 + shift)
        shift += 1
    return table
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols (symbols being variable-length strings), e.g. a tuple of strings.

    Returns:
        A set of `(left, right)` tuples, one per distinct pair of neighbouring symbols. A word with fewer than
        two symbols, including an empty word, yields an empty set.
    """
    # Pairing the word with itself shifted by one enumerates neighbours and is safe on empty input,
    # unlike indexing `word[0]`.
    return set(zip(word, word[1:]))
class Qwen2Tokenizer(PreTrainedTokenizer):
"""
Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```python
>>> from transformers import Qwen2Tokenizer
>>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
>>> tokenizer("Hello world")["input_ids"]
[9707, 1879]
>>> tokenizer(" Hello world")["input_ids"]
[21927, 1879]
```
This is expected.
You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*):
The beginning of sequence token. Not applicable for this tokenizer.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding, for example when batching sequences of different lengths.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
split_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the special tokens should be split during the tokenization process. The default behavior is
to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")`
returns `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give
`['<', '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
    self,
    vocab_file,
    merges_file,
    errors="replace",
    unk_token="<|endoftext|>",
    bos_token=None,
    eos_token="<|endoftext|>",
    pad_token="<|endoftext|>",
    clean_up_tokenization_spaces=False,
    split_special_tokens=False,
    **kwargs,
):
    """
    Load the vocabulary and merge table, then initialize the base tokenizer.

    Args:
        vocab_file: Path to a UTF-8 JSON file mapping token strings to ids.
        merges_file: Path to a UTF-8 text file with one whitespace-separated merge per line; blank lines and
            a leading `#version:` line are skipped.
        errors: Stored on the instance and forwarded to the base class.
        unk_token, bos_token, eos_token, pad_token: Special tokens; plain strings are wrapped in `AddedToken`.
        clean_up_tokenization_spaces, split_special_tokens: Forwarded to the base class.
        **kwargs: Forwarded to the base class. A truthy `add_prefix_space` only triggers a warning.
    """

    def _as_special(token):
        # Qwen vocab does not contain control tokens; added tokens need to be special
        if isinstance(token, str):
            return AddedToken(
                token, lstrip=False, rstrip=False, special=True, normalized=False
            )
        return token

    bos_token = _as_special(bos_token)
    eos_token = _as_special(eos_token)
    unk_token = _as_special(unk_token)
    pad_token = _as_special(pad_token)

    with open(vocab_file, encoding="utf-8") as vocab_handle:
        self.encoder = json.load(vocab_handle)
    self.decoder = {v: k for k, v in self.encoder.items()}
    self.errors = errors  # how to handle errors in decoding
    self.byte_encoder = bytes_to_unicode()
    self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

    bpe_merges = []
    with open(merges_file, encoding="utf-8") as merges_handle:
        for i, line in enumerate(merges_handle):
            line = line.strip()
            if (i == 0 and line.startswith("#version:")) or not line:
                continue
            bpe_merges.append(tuple(line.split()))
    # Rank of a merge is its position in the merges file; lower rank merges first.
    self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
    # NOTE: the cache can grow without bound and will get really large for long running processes
    # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
    # not a memory leak but appears as one.
    # GPT2Tokenizer has the same problem, so let's be consistent.
    self.cache = {}

    self.pat = re.compile(PRETOKENIZE_REGEX)

    if kwargs.get("add_prefix_space", False):
        # `__name__` (with trailing underscores) is required here: inside a class body `__name` is
        # name-mangled and would raise AttributeError instead of logging the warning.
        logger.warning_once(
            f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
        )

    super().__init__(
        errors=errors,
        bos_token=bos_token,
        eos_token=eos_token,
        pad_token=pad_token,
        unk_token=unk_token,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        split_special_tokens=split_special_tokens,
        **kwargs,
    )
@property
def vocab_size(self) -> int:
    """Size of the base vocabulary.

    Returns:
        `int`: The number of entries in `self.encoder`. Only `self.encoder` is counted.
    """
    base_vocab = self.encoder
    return len(base_vocab)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
def get_vocab(self):
    """Return a new dict combining the base vocabulary with the added tokens.

    Returns:
        `dict`: A copy of `self.encoder` updated with the entries of `self.added_tokens_encoder`
        (added-token entries win on key collisions).
    """
    base_vocab = self.encoder
    extra_tokens = self.added_tokens_encoder
    return dict(base_vocab, **extra_tokens)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
def bpe(self, token):
    """Apply the learned BPE merges to a single pre-tokenized piece.

    Args:
        token (`str`): The piece to split; each character is treated as one initial symbol.

    Returns:
        `str`: The resulting symbols joined by single spaces. When `get_pairs` yields no pairs
        the input is returned unchanged (and is not stored in the cache).
    """
    # Serve repeated pieces straight from the memo table.
    try:
        return self.cache[token]
    except KeyError:
        pass

    symbols = tuple(token)
    candidates = get_pairs(symbols)
    if not candidates:
        return token

    ranks = self.bpe_ranks
    unranked = float("inf")
    while True:
        # Pick the pair with the lowest merge rank; pairs without a rank sort last.
        best = min(candidates, key=lambda pair: ranks.get(pair, unranked))
        if best not in ranks:
            # No remaining pair is a known merge.
            break
        left, right = best

        merged = []
        pos = 0
        while pos < len(symbols):
            try:
                hit = symbols.index(left, pos)
            except ValueError:
                # `left` does not occur again: copy the tail as-is.
                merged.extend(symbols[pos:])
                break
            merged.extend(symbols[pos:hit])
            pos = hit

            followed_by_right = pos < len(symbols) - 1 and symbols[pos + 1] == right
            if symbols[pos] == left and followed_by_right:
                merged.append(left + right)
                pos += 2
            else:
                merged.append(symbols[pos])
                pos += 1

        symbols = tuple(merged)
        if len(symbols) == 1:
            break
        candidates = get_pairs(symbols)

    joined = " ".join(symbols)
    self.cache[token] = joined
    return joined
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
def _tokenize(self, text):
    """Split `text` into BPE tokens.

    Args:
        text (`str`): The text to tokenize.

    Returns:
        `list[str]`: The BPE tokens of every match of `self.pat` in `text`, in order.
    """
    byte_to_char = self.byte_encoder
    pieces = []
    for chunk in re.findall(self.pat, text):
        # Map the UTF-8 bytes of the chunk to their unicode stand-ins so the BPE never sees
        # control characters (spaces in our case).
        mapped = "".join([byte_to_char[byte] for byte in chunk.encode("utf-8")])
        pieces += self.bpe(mapped).split(" ")
    return pieces
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
    """Look up the id of a token (str) in the vocab.

    Falls back to the id stored for `self.unk_token`, which is `None` when that token is
    not in `self.encoder` either.
    """
    vocab = self.encoder
    fallback_id = vocab.get(self.unk_token)
    return vocab.get(token, fallback_id)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
    """Look up the token (str) for an index (integer) in the vocab; `None` if the index is unknown."""
    id_to_token = self.decoder
    return id_to_token.get(index, None)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
    """Join a sequence of tokens (strings) back into one decoded string.

    Every character of the concatenated tokens is mapped to a byte through `self.byte_decoder`,
    and the resulting bytes are decoded as UTF-8 using the `self.errors` error policy.
    """
    joined = "".join(tokens)
    raw_bytes = bytearray(self.byte_decoder[char] for char in joined)
    return raw_bytes.decode("utf-8", errors=self.errors)
def decode(
    self,
    token_ids,
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: Optional[bool] = False,
    spaces_between_special_tokens: bool = False,
    **kwargs,
) -> str:
    """Decode token ids by delegating to the parent class with Qwen2-specific defaults.

    All arguments are forwarded unchanged to `super().decode`. The override exists only to
    change a default: `spaces_between_special_tokens` defaults to True for `_decode` in slow
    tokenizers and cannot be configured elsewhere, but it should default to False for
    Qwen2Tokenizer.
    """
    decoded = super().decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        spaces_between_special_tokens=spaces_between_special_tokens,
        **kwargs,
    )
    return decoded
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
def save_vocabulary(
    self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
    """Write the vocabulary and the BPE merges into `save_directory`.

    Args:
        save_directory (`str`): An existing directory to write into.
        filename_prefix (`str`, *optional*): When given, `"<prefix>-"` is prepended to both file names.

    Returns:
        The paths `(vocab_file, merge_file)`, or `None` (after logging an error) when
        `save_directory` is not a directory.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return None

    name_prefix = filename_prefix + "-" if filename_prefix else ""
    vocab_file = os.path.join(
        save_directory, name_prefix + VOCAB_FILES_NAMES["vocab_file"]
    )
    merge_file = os.path.join(
        save_directory, name_prefix + VOCAB_FILES_NAMES["merges_file"]
    )

    serialized_vocab = json.dumps(
        self.encoder, indent=2, sort_keys=True, ensure_ascii=False
    )
    with open(vocab_file, "w", encoding="utf-8") as f:
        f.write(serialized_vocab + "\n")

    # Merges are written in rank order; `expected_rank` tracks the rank the next merge
    # should have so that gaps in the ranking can be reported.
    expected_rank = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        writer.write("#version: 0.2\n")
        ranked_merges = sorted(self.bpe_ranks.items(), key=lambda kv: kv[1])
        for bpe_tokens, token_index in ranked_merges:
            if expected_rank != token_index:
                logger.warning(
                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!"
                )
                expected_rank = token_index
            writer.write(" ".join(bpe_tokens) + "\n")
            expected_rank += 1

    return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs):
    """Normalize `text` to Unicode NFC form before tokenization.

    Returns:
        `tuple`: The normalized text and the keyword arguments, passed through untouched.
    """
    normalized = unicodedata.normalize("NFC", text)
    return normalized, kwargs
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Tokenization classes for Qwen2."""
from typing import Optional, Tuple
from transformers.tokenization_utils import AddedToken
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
from .tokenization_qwen2 import Qwen2Tokenizer
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# Maps each tokenizer resource key to the file name used for it; assigned to
# `Qwen2TokenizerFast.vocab_files_names` below.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
    "tokenizer_file": "tokenizer.json",
}
# NOTE(review): presumably the maximum input length (in tokens) per pretrained tokenizer;
# not referenced anywhere in the visible code — confirm before relying on it.
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
class Qwen2TokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    Like GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
    be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import Qwen2TokenizerFast

    >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
    >>> tokenizer("Hello world")["input_ids"]
    [9707, 1879]

    >>> tokenizer(" Hello world")["input_ids"]
    [21927, 1879]
    ```
    This is expected.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead. Not applicable to this tokenizer.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not applicable for this tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding, for example when batching sequences of different lengths.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = Qwen2Tokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        **kwargs,
    ):
        """Wrap plain-string special tokens in `AddedToken` and forward everything to the base class."""

        def _as_special(token):
            # Plain strings become special, non-normalized, non-stripping `AddedToken`s;
            # anything else (including `None`) is passed through unchanged.
            if not isinstance(token, str):
                return token
            return AddedToken(
                token, lstrip=False, rstrip=False, special=True, normalized=False
            )

        # We need to at least pass vocab_file and merges_file to base class
        # in case a slow tokenizer needs to be initialized; other can be
        # configured through files.
        # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
        bos_token = _as_special(bos_token)
        eos_token = _as_special(eos_token)
        unk_token = _as_special(unk_token)
        pad_token = _as_special(pad_token)

        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the backend tokenizer model into `save_directory` and return the written files as a tuple."""
        saved_files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(saved_files)
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_torch_available,
is_vision_available,
)
# Submodule name -> public names it provides. Handed to `_LazyModule` at the bottom of the file.
_import_structure = {
    "configuration_siglip": [
        "SiglipConfig",
        "SiglipTextConfig",
        "SiglipVisionConfig",
    ],
    "processing_siglip": ["SiglipProcessor"],
}

# Submodules that depend on an optional package are only registered when that package is available.
if is_sentencepiece_available():
    _import_structure["tokenization_siglip"] = ["SiglipTokenizer"]

if is_vision_available():
    _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]

if is_torch_available():
    _import_structure["modeling_siglip"] = [
        "SiglipModel",
        "SiglipPreTrainedModel",
        "SiglipTextModel",
        "SiglipVisionModel",
        "SiglipForImageClassification",
    ]

if TYPE_CHECKING:
    # Direct imports, guarded by the same availability checks as above.
    from .configuration_siglip import (
        SiglipConfig,
        SiglipTextConfig,
        SiglipVisionConfig,
    )
    from .processing_siglip import SiglipProcessor

    if is_sentencepiece_available():
        from .tokenization_siglip import SiglipTokenizer

    if is_vision_available():
        from .image_processing_siglip import SiglipImageProcessor

    if is_torch_available():
        from .modeling_siglip import (
            SiglipForImageClassification,
            SiglipModel,
            SiglipPreTrainedModel,
            SiglipTextModel,
            SiglipVisionModel,
        )
else:
    import sys

    # Replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
    )
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Siglip model configuration"""
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class SiglipTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
    Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`SiglipModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        max_position_embeddings (`int`, *optional*, defaults to 64):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            The id of the padding token in the vocabulary.
        bos_token_id (`int`, *optional*, defaults to 49406):
            The id of the beginning-of-sequence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 49407):
            The id of the end-of-sequence token in the vocabulary.

    Example:

    ```python
    >>> from transformers import SiglipTextConfig, SiglipTextModel

    >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipTextConfig()

    >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_text_model"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=64,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        # This differs from `CLIPTokenizer`'s default and from openai/siglip
        # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        **kwargs,
    ):
        # The special-token ids are handled by the base class; everything else is stored here.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Embedding / encoder dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings

        # Normalization, activation and regularization settings.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the text configuration, unwrapping it from a composite `siglip` config if necessary.

        A warning is logged when the loaded `model_type` differs from `cls.model_type`.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # get the text config dict if we are loading from SiglipConfig
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["text_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type"):
            loaded_type = config_dict["model_type"]
            if loaded_type != cls.model_type:
                logger.warning(
                    f"You are using a model of type {loaded_type} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(config_dict, **kwargs)
class SiglipVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Regularization, normalization and activation settings.
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the vision configuration, unwrapping it from a composite `siglip` config if necessary.

        A warning is logged when the loaded `model_type` differs from `cls.model_type`.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # get the vision config dict if we are loading from SiglipConfig
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type"):
            loaded_type = config_dict["model_type"]
            if loaded_type != cls.model_type:
                logger.warning(
                    f"You are using a model of type {loaded_type} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(config_dict, **kwargs)
class SiglipConfig(PretrainedConfig):
    r"""
    [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
    instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`SiglipTextConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import SiglipConfig, SiglipModel

    >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipConfig()

    >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
    >>> from transformers import SiglipTextConfig, SiglipVisionConfig

    >>> # Initializing a SiglipText and SiglipVision configuration
    >>> config_text = SiglipTextConfig()
    >>> config_vision = SiglipVisionConfig()

    >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
    ```"""

    model_type = "siglip"

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        super().__init__(**kwargs)

        # A missing sub-config falls back to an empty dict, i.e. the sub-config's own defaults.
        if text_config is None:
            logger.info(
                "`text_config` is `None`. Initializing the `SiglipTextConfig` with default values."
            )
            text_config = {}

        if vision_config is None:
            logger.info(
                "`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values."
            )
            vision_config = {}

        self.text_config = SiglipTextConfig(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)

        self.initializer_factor = 1.0

    @classmethod
    def from_text_vision_configs(
        cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs
    ):
        r"""
        Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
        model configuration.

        Returns:
            [`SiglipConfig`]: An instance of a configuration object
        """
        text_dict = text_config.to_dict()
        vision_dict = vision_config.to_dict()
        return cls(text_config=text_dict, vision_config=vision_dict, **kwargs)
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Convert SigLIP checkpoints from the original repository.
URL: https://github.com/google-research/big_vision/tree/main
"""
import argparse
import collections
from pathlib import Path
import numpy as np
import requests
import torch
from huggingface_hub import hf_hub_download
from numpy import load
from PIL import Image
from transformers import (
SiglipConfig,
SiglipImageProcessor,
SiglipModel,
SiglipProcessor,
SiglipTokenizer,
)
from transformers.utils import logging
# Raise the transformers logging verbosity to INFO for this script.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# Maps each supported model name to the path of its original `.npz` checkpoint.
# NOTE(review): these are hard-coded absolute paths under one user's home directory;
# they have to be adapted before the script can run elsewhere.
model_name_to_checkpoint = {
    # base checkpoints
    "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
    "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
    "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
    "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
    # large checkpoints
    "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
    "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
    # multilingual checkpoint
    "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
    # so400m checkpoints
    "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
}
# Maps each supported model name to the image size that `get_siglip_config` writes into
# `config.vision_config.image_size`.
model_name_to_image_size = {
    "siglip-base-patch16-224": 224,
    "siglip-base-patch16-256": 256,
    "siglip-base-patch16-384": 384,
    "siglip-base-patch16-512": 512,
    "siglip-large-patch16-256": 256,
    "siglip-large-patch16-384": 384,
    "siglip-base-patch16-256-i18n": 256,
    "siglip-so400m-patch14-384": 384,
}
def get_siglip_config(model_name):
    """Build the `SiglipConfig` matching a checkpoint name.

    Args:
        model_name (`str`): A key of `model_name_to_image_size`; the substrings `"i18n"`,
            `"patch16"`, `"base"`, `"large"` and `"so400m"` select the remaining settings.

    Returns:
        `SiglipConfig`: The default config with image size, patch size, vocab size and (for
        non-base models) the architecture sizes overridden.

    Raises:
        KeyError: If `model_name` is not a key of `model_name_to_image_size`.
        ValueError: If `model_name` contains none of `"base"`, `"large"` or `"so400m"`.
    """
    config = SiglipConfig()

    image_size = model_name_to_image_size[model_name]
    vocab_size = 32000 if "i18n" not in model_name else 250000
    patch_size = 14 if "patch16" not in model_name else 16

    text_config = config.text_config
    vision_config = config.vision_config

    # size of the architecture
    vision_config.image_size = image_size
    vision_config.patch_size = patch_size
    text_config.vocab_size = vocab_size

    # Both towers share the same size overrides; "base" keeps the config defaults.
    if "base" in model_name:
        size_overrides = None
    elif "large" in model_name:
        size_overrides = {
            "hidden_size": 1024,
            "intermediate_size": 4096,
            "num_hidden_layers": 24,
            "num_attention_heads": 16,
        }
    elif "so400m" in model_name:
        size_overrides = {
            "hidden_size": 1152,
            "intermediate_size": 4304,
            "num_hidden_layers": 27,
            "num_attention_heads": 16,
        }
    else:
        raise ValueError("Model not supported")

    if size_overrides is not None:
        for tower_config in (text_config, vision_config):
            for attribute, value in size_overrides.items():
                setattr(tower_config, attribute, value)

    return config
def create_rename_keys(config):
    """List the `(original_name, new_name)` pairs used to rename checkpoint parameters.

    Args:
        config: Object exposing `vision_config.num_hidden_layers` and
            `text_config.num_hidden_layers`; one group of per-layer pairs is emitted per layer.

    Returns:
        `list[tuple[str, str]]`: Pairs in a fixed order: vision embeddings, vision layers, vision
        final norm and head, text embeddings, text layers, text final norm and head, then the
        learned temperature and bias.
    """
    # Per-layer parameters, identical for both towers apart from the name prefixes.
    layer_suffixes = (
        ("LayerNorm_0/scale", "layer_norm1.weight"),
        ("LayerNorm_0/bias", "layer_norm1.bias"),
        ("LayerNorm_1/scale", "layer_norm2.weight"),
        ("LayerNorm_1/bias", "layer_norm2.bias"),
        ("MlpBlock_0/Dense_0/kernel", "mlp.fc1.weight"),
        ("MlpBlock_0/Dense_0/bias", "mlp.fc1.bias"),
        ("MlpBlock_0/Dense_1/kernel", "mlp.fc2.weight"),
        ("MlpBlock_0/Dense_1/bias", "mlp.fc2.bias"),
        ("MultiHeadDotProductAttention_0/key/kernel", "self_attn.k_proj.weight"),
        ("MultiHeadDotProductAttention_0/key/bias", "self_attn.k_proj.bias"),
        ("MultiHeadDotProductAttention_0/value/kernel", "self_attn.v_proj.weight"),
        ("MultiHeadDotProductAttention_0/value/bias", "self_attn.v_proj.bias"),
        ("MultiHeadDotProductAttention_0/query/kernel", "self_attn.q_proj.weight"),
        ("MultiHeadDotProductAttention_0/query/bias", "self_attn.q_proj.bias"),
        ("MultiHeadDotProductAttention_0/out/kernel", "self_attn.out_proj.weight"),
        ("MultiHeadDotProductAttention_0/out/bias", "self_attn.out_proj.bias"),
    )

    def layer_pairs(src_template, dst_template, num_layers):
        # Expand the suffix table once per layer index, keeping layer-major order.
        return [
            (f"{src_template.format(idx)}/{src}", f"{dst_template.format(idx)}.{dst}")
            for idx in range(num_layers)
            for src, dst in layer_suffixes
        ]

    # vision encoder
    rename_keys = [
        ("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"),
        ("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"),
        ("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"),
    ]
    rename_keys += layer_pairs(
        "params/img/Transformer/encoderblock_{}",
        "vision_model.encoder.layers.{}",
        config.vision_config.num_hidden_layers,
    )
    rename_keys += [
        ("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"),
        ("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"),
        ("params/img/MAPHead_0/probe", "vision_model.head.probe"),
        ("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"),
        ("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"),
        ("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"),
        ("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"),
        ("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"),
        ("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"),
        ("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"),
        ("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"),
    ]

    # text encoder
    rename_keys += [
        ("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"),
        ("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"),
    ]
    rename_keys += layer_pairs(
        "params/txt/Encoder_0/encoderblock_{}",
        "text_model.encoder.layers.{}",
        config.text_config.num_hidden_layers,
    )
    rename_keys += [
        ("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"),
        ("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"),
        ("params/txt/head/kernel", "text_model.head.weight"),
        ("params/txt/head/bias", "text_model.head.bias"),
    ]

    # learned temperature and bias
    rename_keys += [
        ("params/t", "logit_scale"),
        ("params/b", "logit_bias"),
    ]
    return rename_keys
def rename_key(dct, old, new, config):
    """Move ``dct[old]`` to ``dct[new]`` and convert it to the target layout.

    The popped value is a NumPy array. Depending on the destination name it
    is reshaped and/or transposed, then stored as a ``torch`` tensor.

    Args:
        dct: Flat mapping of parameter names to NumPy arrays; modified in place.
        old: Source key; it is removed from ``dct``.
        new: Destination key; also selects which layout fixes are applied.
        config: Object exposing ``vision_config.hidden_size`` and
            ``text_config.hidden_size``.
    """
    array = dct.pop(old)

    def collapse_to_hidden(value):
        # Vision is handled before text; the tower config is only looked up
        # when its name actually appears in the destination key.
        for tower in ("vision", "text"):
            if tower in new:
                tower_config = getattr(config, f"{tower}_config")
                value = value.reshape(-1, tower_config.hidden_size)
        return value

    # Attention projections are first flattened to (-1, hidden_size).
    projection_names = ("out_proj", "v_proj", "k_proj", "q_proj")
    if any(name in new for name in projection_names):
        array = collapse_to_hidden(array)

    is_embedding_table = "position_embedding" in new or "token_embedding" in new
    if "patch_embedding.weight" in new:
        # Reorder the 4-D convolution kernel axes.
        array = array.transpose(3, 2, 0, 1)
    elif new.endswith("weight") and not is_embedding_table:
        array = array.T

    if "position_embedding" in new:
        array = collapse_to_hidden(array)

    if new.endswith("bias"):
        array = array.reshape(-1)

    dct[new] = torch.from_numpy(array)
def read_in_q_k_v_head(state_dict, config):
    """Fuse the pooling head's separate q/k/v projections into packed tensors.

    The six ``params/img/MAPHead_0/MultiHeadDotProductAttention_0/*`` entries
    are removed from ``state_dict``. They are replaced by
    ``vision_model.head.attention.in_proj_weight`` and
    ``vision_model.head.attention.in_proj_bias``, each concatenated along
    axis 0 in query, key, value order.

    Args:
        state_dict: Flat mapping of parameter names to NumPy arrays; modified
            in place.
        config: Object exposing ``vision_config.hidden_size``.
    """
    prefix = "params/img/MAPHead_0/MultiHeadDotProductAttention_0"
    weights = {}
    biases = {}
    # Entries are removed in key, value, query order.
    for name in ("key", "value", "query"):
        kernel = state_dict.pop(f"{prefix}/{name}/kernel")
        weights[name] = kernel.reshape(-1, config.vision_config.hidden_size).T
        biases[name] = state_dict.pop(f"{prefix}/{name}/bias").reshape(-1)

    # The packed tensors are laid out as query, then key, then value.
    packed_order = ("query", "key", "value")
    packed_weight = np.concatenate([weights[name] for name in packed_order], axis=0)
    packed_bias = np.concatenate([biases[name] for name in packed_order], axis=0)

    state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
        packed_weight
    )
    state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
        packed_bias
    )
# We will verify our results on an image of cute cats
def prepare_img():
    """Download the COCO val2017 sample image and return it as a PIL image.

    Returns:
        The image opened from the streamed HTTP response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps the script from hanging forever on a stalled connection.
    response = requests.get(url, stream=True, timeout=30)
    # Fail with a clear HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    image = Image.open(response.raw)
    return image
def flatten_nested_dict(params, parent_key="", sep="/"):
    """Flatten a nested mapping into a single-level dict with joined keys.

    Args:
        params: Mapping whose values may themselves be mutable mappings.
        parent_key: Prefix put in front of every key, followed by ``sep``.
            An empty prefix leaves top-level keys untouched.
        sep: Separator placed between nesting levels.

    Returns:
        A new ``dict`` mapping each joined key path to its leaf value.
    """
    flat = {}
    for name, value in params.items():
        path = sep.join((parent_key, name)) if parent_key else name
        if isinstance(value, collections.abc.MutableMapping):
            # Descend, carrying the accumulated path as the new prefix.
            flat.update(flatten_nested_dict(value, path, sep=sep))
        else:
            flat[path] = value
    return flat
@torch.no_grad()
def convert_siglip_checkpoint(
    model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False
):
    """
    Copy/paste/tweak model's weights to our SigLIP structure.

    Args:
        model_name: Key into ``model_name_to_checkpoint`` selecting the
            checkpoint to convert.
        pytorch_dump_folder_path: Directory to save the converted model and
            processor to; nothing is saved when it is ``None``.
        verify_logits: Whether to compare the output logits with the reference
            values recorded for ``model_name``.
        push_to_hub: Whether to push the converted model and processor to the hub.

    Raises:
        ValueError: If the configured image size has no reference pixel values,
            or if ``verify_logits`` is set and no reference logits are recorded
            for ``model_name``.
    """
    # define default SigLIP configuration
    config = get_siglip_config(model_name)
    # get checkpoint
    checkpoint = model_name_to_checkpoint[model_name]
    # get vocab file
    # NOTE(review): these are absolute paths on the original author's machine.
    if "i18n" in model_name:
        vocab_file = (
            "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
        )
    else:
        vocab_file = (
            "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
        )
    # load original state dict
    data = load(checkpoint)
    state_dict = flatten_nested_dict(data)
    # remove and rename some keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest, config)
    # qkv matrices of attention pooling head need special treatment
    read_in_q_k_v_head(state_dict, config)
    # load HuggingFace model
    model = SiglipModel(config).eval()
    model.load_state_dict(state_dict)
    # create processor
    # important: make tokenizer not return attention_mask since original one doesn't require it
    image_size = config.vision_config.image_size
    size = {"height": image_size, "width": image_size}
    image_processor = SiglipImageProcessor(size=size)
    tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
    processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
    # verify on dummy images and texts
    url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
    image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
    url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
    image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
    texts = ["an apple", "a picture of an apple"]
    inputs = processor(
        images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length"
    )
    # fetch the reference pixel values matching the configured image size
    pixel_value_files = {
        224: "siglip_pixel_values.pt",
        256: "siglip_pixel_values_256.pt",
        384: "siglip_pixel_values_384.pt",
        512: "siglip_pixel_values_512.pt",
    }
    if image_size not in pixel_value_files:
        raise ValueError("Image size not supported")
    filename = pixel_value_files[image_size]
    filepath = hf_hub_download(
        repo_id="nielsr/test-image", filename=filename, repo_type="dataset"
    )
    original_pixel_values = torch.load(filepath)
    # verify input_ids against original ones
    filepath = hf_hub_download(
        repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset"
    )
    original_input_ids = torch.load(filepath)
    if "i18n" not in model_name:
        assert inputs.input_ids.tolist() == original_input_ids.tolist()
    print("Mean of original pixel values:", original_pixel_values.mean())
    print("Mean of new pixel values:", inputs.pixel_values.mean())
    # note: we're testing with original pixel values here since we don't have exact pixel values
    # (gradients are already disabled by the decorator on this function)
    outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
    print(outputs.logits_per_image[:3, :3])
    probs = torch.sigmoid(outputs.logits_per_image)  # these are the probabilities
    print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
    print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
    if verify_logits:
        # Reference logits per model name.
        expected_slices = {
            "siglip-base-patch16-224": [[-2.9621, -2.1672], [-0.2713, 0.2910]],
            "siglip-base-patch16-256": [[-3.1146, -1.9894], [-0.7312, 0.6387]],
            "siglip-base-patch16-384": [[-2.8098, -2.1891], [-0.4242, 0.4102]],
            "siglip-base-patch16-512": [[-2.7899, -2.2668], [-0.4295, -0.0735]],
            "siglip-large-patch16-256": [[-1.5827, -0.5801], [-0.9153, 0.1363]],
            "siglip-large-patch16-384": [[-2.1523, -0.2899], [-0.2959, 0.7884]],
            "siglip-so400m-patch14-384": [[-1.2441, -0.6649], [-0.7060, 0.7374]],
            "siglip-base-patch16-256-i18n": [[-0.9064, 0.1073], [-0.0299, 0.5304]],
        }
        if model_name not in expected_slices:
            # Without this guard an unlisted model would fail below with an
            # unbound-variable error that hides the real cause.
            raise ValueError(
                f"No reference logits are available for {model_name}; "
                "disable logit verification to convert this model."
            )
        expected_slice = torch.tensor(expected_slices[model_name])
        assert torch.allclose(
            outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4
        )
        print("Looks ok!")
    if pytorch_dump_folder_path is not None:
        # parents=True lets nested output paths be created in one go.
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)
    if push_to_hub:
        model.push_to_hub(f"nielsr/{model_name}")
        processor.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Every argument is optional and falls back to the default given below.
    parser.add_argument(
        "--model_name",
        default="siglip-base-patch16-224",
        type=str,
        choices=model_name_to_checkpoint.keys(),
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    # `store_false` means verification is ON by default and passing the flag
    # turns it OFF; the help text states this explicitly.
    parser.add_argument(
        "--verify_logits",
        action="store_false",
        help=(
            "Pass this flag to skip verifying logits against the original implementation"
            " (verification is enabled by default)."
        ),
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model to the 🤗 hub.",
    )
    args = parser.parse_args()
    convert_siglip_checkpoint(
        args.model_name,
        args.pytorch_dump_folder_path,
        args.verify_logits,
        args.push_to_hub,
    )
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Image processor class for SigLIP."""
from typing import Dict, List, Optional, Union
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from transformers.utils import (
TensorType,
filter_out_non_signature_kwargs,
is_vision_available,
logging,
)
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class SiglipImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SigLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
            `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*):
            Whether to convert the image to RGB. When left unset (`None`), no conversion is performed unless
            `do_convert_rgb` is passed to the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        do_convert_rgb: Optional[bool] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.

        Returns:
            `BatchFeature`: A batch whose `pixel_values` entry holds the processed images, converted to
            `return_tensors` when it is set.

        Raises:
            ValueError: If `images` contains an unsupported image type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # RGB conversion must see the raw inputs: `convert_to_rgb` converts PIL
        # images and passes other inputs through unchanged, so running it after
        # the numpy conversion would silently do nothing.
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            height, width = size["height"], size["width"]
            images = [
                resize(
                    image=image,
                    size=(height, width),
                    resample=resample,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(
                    image=image,
                    scale=rescale_factor,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        images = [
            to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""PyTorch Siglip model."""
import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from liger_kernel.transformers import LigerCrossEntropyLoss
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
# General docstring
# Names for the documentation machinery; presumably consumed by the docstring
# decorators imported by this module (TODO confirm against their usages).
_CONFIG_FOR_DOC = "SiglipConfig"
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled in place.
    """
    with torch.no_grad():
        # Truncate a standard normal to [a, b] first, then scale and shift.
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    # Returning the tensor honours the declared return type; callers that
    # ignored the previous implicit None are unaffected.
    return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``scale / denom``.

    Args:
        tensor: Tensor to fill in place; its fan-in and fan-out are obtained
            from ``_calculate_fan_in_and_fan_out``.
        scale: Numerator of the target variance.
        mode: Selects ``denom``: ``"fan_in"``, ``"fan_out"``, or ``"fan_avg"``
            (the mean of fan-in and fan-out).
        distribution: ``"truncated_normal"``, ``"normal"``, or ``"uniform"``.

    Raises:
        ValueError: If ``mode`` or ``distribution`` is not a supported value.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Reject unknown modes explicitly; otherwise `denom` would be unbound.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """Initialise ``tensor`` in place with fan-in scaled truncated-normal values."""
    init_options = {"mode": "fan_in", "distribution": "truncated_normal"}
    variance_scaling_(tensor, **init_options)
def default_flax_embed_init(tensor):
    """Initialise ``tensor`` in place with fan-in scaled normal values."""
    init_options = {"mode": "fan_in", "distribution": "normal"}
    variance_scaling_(tensor, **init_options)
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Every field defaults to `None`.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Every field defaults to `None`.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
    """
    Combined output of the SigLIP text and vision towers.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """Return the stored values as a tuple, with the two nested tower outputs converted to tuples too."""
        nested_outputs = ("text_model_output", "vision_model_output")
        values = []
        for key in self.keys():
            if key in nested_outputs:
                values.append(getattr(self, key).to_tuple())
            else:
                values.append(self[key])
        return tuple(values)
class SiglipVisionEmbeddings(nn.Module):
    """Convert pixel values into patch embeddings with learned position embeddings added."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Resize the learned position embeddings to the patch grid of a `height` x `width` input.

        The stored embeddings are returned unchanged when the patch count matches the stored positions and the
        input is square, unless the model is being traced. Otherwise they are resized with bicubic interpolation.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        shapes_match = num_patches == num_positions and height == width
        if not torch.jit.is_tracing() and shapes_match:
            return self.position_embedding(self.position_ids)

        dim = embeddings.shape[-1]
        target_size = (height // self.patch_size, width // self.patch_size)
        side = torch_int(num_positions**0.5)

        # Lay the position table out as a (1, dim, side, side) grid for interpolation.
        grid = self.position_embedding.weight.unsqueeze(0)
        grid = grid.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(
            grid,
            size=target_size,
            mode="bicubic",
            align_corners=False,
        )
        # Back to (1, num_patches, dim).
        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(
        self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False
    ) -> torch.Tensor:
        """Embed `pixel_values` as a patch sequence and add position embeddings."""
        _batch, _channels, height, width = pixel_values.shape
        # shape = [*, width, grid, grid]
        patch_grid = self.patch_embedding(pixel_values)
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(patch_tokens, height, width)
        else:
            positions = self.position_embedding(self.position_ids)
        return patch_tokens + positions
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
class SiglipTextEmbeddings(nn.Module):
    """Sum of token embeddings and learned position embeddings."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        hidden_dim = config.hidden_size
        max_positions = config.max_position_embeddings

        self.token_embedding = nn.Embedding(config.vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_positions, hidden_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(max_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Return token embeddings plus position embeddings.

        `inputs_embeds` is used as given when provided; otherwise `input_ids` are looked up in the token table.
        When `position_ids` is omitted, the first `seq_length` stored positions are used.
        """
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim != self.num_heads * self.head_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order (k, v, q, out) is kept so parameter registration and
        # random initialisation order do not change.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Computes scaled dot-product self-attention with an fp32 softmax.
        `attention_mask`, when given, must have shape
        `(batch_size, 1, q_len, k_v_seq_len)` and is added to the raw scores.
        The attention probabilities are always returned as the second element;
        `output_attentions` is not consulted here.
        """
        batch_size, q_len, _ = hidden_states.size()
        per_head_shape = (batch_size, q_len, self.num_heads, self.head_dim)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return projected.view(*per_head_shape).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        k_v_seq_len = key_states.shape[-2]
        scores = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        scores_shape = (batch_size, self.num_heads, q_len, k_v_seq_len)
        if scores.size() != scores_shape:
            raise ValueError(
                f"Attention weights should be of size {scores_shape}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            mask_shape = (batch_size, 1, q_len, k_v_seq_len)
            if attention_mask.size() != mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {mask_shape}, but is {attention_mask.size()}"
                )
            scores = scores + attention_mask

        # upcast attention to fp32
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
        probs = probs.to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            probs, p=self.dropout, training=self.training
        )

        context = torch.matmul(attn_weights, value_states)
        context_shape = (batch_size, self.num_heads, q_len, self.head_dim)
        if context.size() != context_shape:
            raise ValueError(
                f"`attn_output` should be of size {context_shape}, but is"
                f" {context.size()}"
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.reshape(batch_size, q_len, self.embed_dim)

        return self.out_proj(merged), attn_weights
class SiglipFlashAttention2(SiglipAttention):
    """
    SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    is_causal = False

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention through `_flash_attention_forward`.

        Args:
            hidden_states: Tensor unpacked as `(batch_size, q_len, _)`.
            attention_mask: Passed through unchanged to `_flash_attention_forward`.
            output_attentions: Accepted for signature parity with `SiglipAttention`
                but ignored; this path never returns attention weights.

        Returns:
            `(attn_output, None)`, where `attn_output` has been reshaped to
            `(batch_size, q_len, embed_dim)` and passed through `out_proj`.
        """
        batch_size, q_len, _ = hidden_states.size()

        # Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim], which is exactly what
        # a plain `view` of the projections yields, so no transposes are needed.
        flash_shape = (batch_size, q_len, self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).view(*flash_shape)
        key_states = self.k_proj(hidden_states).view(*flash_shape)
        value_states = self.v_proj(hidden_states).view(*flash_shape)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # Merge the head axes back into the embedding dimension.
        attn_output = attn_output.reshape(
            batch_size, q_len, self.embed_dim
        ).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, None
class SiglipSdpaAttention(SiglipAttention):
    """
    Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    is_causal = False

    # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention via SDPA; returns `(attn_output, None)`.

        When `output_attentions` is truthy, a warning is logged once and the call
        is delegated to `SiglipAttention.forward`, which does return the weights.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        batch_size, q_len, _ = hidden_states.size()
        per_head_shape = (batch_size, q_len, self.num_heads, self.head_dim)

        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(*per_head_shape).transpose(1, 2)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if attention_mask is not None and query_states.device.type == "cuda":
            query_states, key_states, value_states = (
                query_states.contiguous(),
                key_states.contiguous(),
                value_states.contiguous(),
            )

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = False
        if self.is_causal and q_len > 1:
            is_causal = True

        dropout_p = self.dropout if self.training else 0.0
        context = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, q_len, self.embed_dim)

        return self.out_proj(merged), None
# Lookup table from an attention-implementation key to the class that implements it.
# `SiglipEncoderLayer` indexes this with `config._attn_implementation`.
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
}
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear.

    The activation is looked up in `ACT2FN` by `config.hidden_act`; the layers
    map `hidden_size -> intermediate_size -> hidden_size`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc2(activation_fn(fc1(hidden_states)))`."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class SiglipEncoderLayer(nn.Module):
    """One encoder block: layer norm -> self-attention -> residual add, then
    layer norm -> MLP -> residual add."""

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.embed_dim = config.hidden_size

        # The attention class is selected by the config's attention implementation key.
        attention_cls = SIGLIP_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when
            `output_attentions` is truthy.
        """
        # Attention sub-block with residual connection.
        attn_output, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with residual connection.
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    # Each module class is listed once (the list previously repeated
    # "SiglipEncoderLayer").
    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type.

        The branches are tested in order and the first match wins, so the
        specific Siglip module types are checked before the generic
        `nn.Linear` / `nn.Conv2d` / `nn.LayerNorm` fallbacks.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # The width comes from the vision sub-config when the model holds a
            # composite SiglipConfig, otherwise from the config itself.
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            projections = (module.q_proj, module.k_proj, module.v_proj, module.out_proj)
            # Weights first, then biases, in q/k/v/out order.
            for projection in projections:
                nn.init.xavier_uniform_(projection.weight)
            for projection in projections:
                nn.init.zeros_(projection.bias)
        elif isinstance(module, SiglipMLP):
            feed_forward = (module.fc1, module.fc2)
            for linear in feed_forward:
                nn.init.xavier_uniform_(linear.weight)
            for linear in feed_forward:
                nn.init.normal_(linear.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # log(1.0) == 0, i.e. the logit scale starts at zero in log space.
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, SiglipForImageClassification):
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5
                * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# Shared class-level docstring preamble; passed to `add_start_docstrings` on the
# Siglip model classes in this file.
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation for the text-side entry points; passed to
# `add_start_docstrings_to_model_forward` on the text forward methods and on
# `SiglipModel.get_text_features`.
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Argument documentation for the vision-side entry points; passed to
# `add_start_docstrings_to_model_forward` on the vision forward methods and on
# `SiglipModel.get_image_features`.
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Argument documentation covering both text and image inputs.
# NOTE(review): no use of this constant is visible in this part of the file;
# presumably it documents the joint `SiglipModel.forward` — confirm.
SIGLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.config = config
        layers = [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(layers)
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Run `inputs_embeds` through every encoder layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Whether to collect each layer's attention tensor. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the hidden states before each layer and after the last one. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a `BaseModelOutput` instead of a plain tuple. Defaults to
                `config.use_return_dict`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        state_history = [] if output_hidden_states else None
        attention_history = [] if output_attentions else None
        checkpointing = self.gradient_checkpointing and self.training

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of this layer.
                state_history.append(hidden_states)

            if checkpointing:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                attention_history.append(layer_outputs[1])

        if output_hidden_states:
            # Record the output of the final layer as well.
            state_history.append(hidden_states)

        encoder_states = None if state_history is None else tuple(state_history)
        all_attentions = None if attention_history is None else tuple(attention_history)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=encoder_states,
                attentions=all_attentions,
            )

        # Tuple form: entries that were not requested are dropped.
        candidates = [hidden_states, encoder_states, all_attentions]
        return tuple(item for item in candidates if item is not None)
class SiglipTextTransformer(nn.Module):
    # Text tower: embeddings -> encoder -> final layer norm, with a linear head
    # applied to the last token's state to produce the pooled output.

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.head = nn.Linear(hidden_size, hidden_size)

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Flatten any leading dimensions into a single batch dimension.
        input_ids = input_ids.view(-1, input_ids.size()[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
        mask_needs_expansion = not (
            attention_mask is None or self._use_flash_attention_2
        )
        if mask_needs_expansion:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(
                attention_mask, hidden_states.dtype
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # Assuming "sticky" EOS tokenization, last token is always EOS.
        pooled_output = self.head(last_hidden_state[:, -1, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@add_start_docstrings(
    """The text model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipTextModel(SiglipPreTrainedModel):
    # Thin `SiglipPreTrainedModel` wrapper around `SiglipTextTransformer`.
    config_class = SiglipTextConfig

    def __init__(self, config: SiglipTextConfig):
        super().__init__(config)
        self.text_model = SiglipTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The input embeddings are the token embedding table of the text tower.
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel

        >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        # Only `return_dict` is resolved here; the other optional flags are
        # resolved inside the wrapped text transformer.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs
class SiglipVisionTransformer(nn.Module):
    # Vision tower: patch embeddings -> encoder -> post layer norm, with an
    # optional attention-pooling head producing the pooled output.

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # The pooling head is built unless the config carries a
        # `vision_use_head` attribute that says otherwise.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig
    )
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        hidden_states = self.embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # No attention mask is passed to the encoder on the vision side.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.post_layernorm(encoder_outputs[0])

        pooler_output = None
        if self.use_head:
            pooler_output = self.head(last_hidden_state)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooler_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooler_output) + encoder_outputs[1:]
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        hidden_size = config.hidden_size

        # A single learned query vector that attends over the whole sequence.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = torch.nn.MultiheadAttention(
            hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        """Pool a sequence into one vector per batch element.

        The probe is repeated along the batch axis and used as the query, with
        `hidden_state` as both key and value. The attended result then goes
        through a layer-norm + MLP residual block, and the single query position
        is returned.
        """
        num_samples = hidden_state.shape[0]
        queries = self.probe.repeat(num_samples, 1, 1)

        attended = self.attention(queries, hidden_state, hidden_state)[0]
        pooled = attended + self.mlp(self.layernorm(attended))

        return pooled[:, 0]
@add_start_docstrings(
    """The vision model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipVisionModel(SiglipPreTrainedModel):
    # Thin `SiglipPreTrainedModel` wrapper around `SiglipVisionTransformer`.
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The input embeddings are the patch embedding module of the vision tower.
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig
    )
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        # Only `return_dict` is resolved here; the other optional flags are
        # resolved inside the wrapped vision transformer.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return outputs
@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
    # Configuration class associated with this model.
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        """Build the text and vision towers plus the learnable logit scale and bias.

        Raises:
            TypeError: If `config.text_config` is not a `SiglipTextConfig` or
                `config.vision_config` is not a `SiglipVisionConfig`.
        """
        super().__init__(config)

        if not isinstance(config.text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(config.vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        # Instantiate the full wrapper models through `_from_config` first
        # (text before vision) ...
        text_wrapper = SiglipTextModel._from_config(config.text_config)
        vision_wrapper = SiglipVisionModel._from_config(config.vision_config)

        # ... then keep only their inner submodules (for backward compatibility).
        self.text_model = text_wrapper.text_model
        self.vision_model = vision_wrapper.vision_model

        # Learnable scalar scale / bias applied to the similarity logits in `forward`.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Arguments left as None fall back to the top-level SigLIP config,
        # not to the config of the text component.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoded = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the text model output is used as the pooled output.
        return encoded[1]

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Arguments left as None fall back to the top-level SigLIP config,
        # not to the config of the vision component.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoded = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Element 1 of the vision model output is used as the pooled output.
        return encoded[1]

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, SiglipOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Arguments left as None fall back to the top-level SigLIP config,
        # not to the configs of the vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of each tower's output is used as its embedding.
        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[1]

        # L2-normalize both embeddings along the last dimension.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity as logits, scaled by exp(logit_scale) and shifted by logit_bias.
        similarity = torch.matmul(
            text_embeds, image_embeds.t().to(text_embeds.device)
        )
        logits_per_text = similarity * self.logit_scale.exp() + self.logit_bias
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
            # Sign matrix: +1 on the diagonal, -1 everywhere else.
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()

        if return_dict:
            return SiglipOutput(
                loss=loss,
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
                text_model_output=text_outputs,
                vision_model_output=vision_outputs,
            )

        # Tuple output: the loss is prepended only when it was computed.
        output = (
            logits_per_image,
            logits_per_text,
            text_embeds,
            image_embeds,
            text_outputs,
            vision_outputs,
        )
        if loss is None:
            return output
        return (loss,) + output
@add_start_docstrings(
    """
    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """,
    SIGLIP_START_DOCSTRING,
)
class SiglipForImageClassification(SiglipPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipConfig) -> None:
        """Create the vision tower and the classification head described by `config`."""
        super().__init__(config)
        self.num_labels = config.num_labels

        # Build the vision model through `_from_config` and keep only its
        # `vision_model` submodule (for backward compatibility).
        self.vision_model = SiglipVisionModel._from_config(
            config.vision_config
        ).vision_model

        # Classifier head: a linear layer, or a pass-through when there are no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(
                config.vision_config.hidden_size, config.num_labels
            )
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, SiglipForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `SiglipModel` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```"""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Average-pool the patch tokens (dim 1 of output element 0), then classify.
        pooled = torch.mean(outputs[0], dim=1)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)

            # Infer the problem type once and store it on the config.
            if self.config.problem_type is None:
                integer_labels = labels.dtype in (torch.long, torch.int)
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and integer_labels:
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                # LigerCrossEntropyLoss is used here in place of the stock CrossEntropyLoss.
                criterion = LigerCrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: logits followed by everything after the first two
        # elements of the vision output; the loss is prepended when present.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""
Image/Text processor class for SigLIP.
"""
from typing import List, Optional, Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import (
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from transformers.utils import TensorType
class SiglipProcessor(ProcessorMixin):
    r"""
    Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.

    [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
    [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.

    Args:
        image_processor ([`SiglipImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`SiglipTokenizer`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = "SiglipTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards `text`
        together with `padding`, `truncation`, `max_length` and `return_tensors` to SiglipTokenizer's
        [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode the text. To prepare the image(s), this method
        forwards the `images` argument to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not
        `None`. Please refer to the docstring of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `TensorType.PYTORCH`):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            ValueError: If both `text` and `images` are `None`.
        """
        has_text = text is not None
        has_images = images is not None

        if not (has_text or has_images):
            raise ValueError(
                "You have to specify either text or images. Both cannot be none."
            )

        # The tokenizer runs before the image processor, as in the original flow.
        encoding = None
        if has_text:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )

        image_features = None
        if has_images:
            image_features = self.image_processor(
                images, return_tensors=return_tensors
            )

        # Images only: wrap the image processor output in a fresh BatchFeature.
        if not has_text:
            return BatchFeature(
                data=dict(**image_features), tensor_type=return_tensors
            )

        # Text present: the tokenizer output is returned, with pixel values
        # attached when images were given as well.
        if has_images:
            encoding["pixel_values"] = image_features.pixel_values
        return encoding

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
    def model_input_names(self):
        """Tokenizer input names followed by image processor input names, de-duplicated in order."""
        combined = (
            self.tokenizer.model_input_names
            + self.image_processor.model_input_names
        )
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return list(dict.fromkeys(combined))
# Copyright 2024 The HuggingFace Inc. team.
# SPDX-License-Identifier: Apache-2.0
"""Tokenization class for SigLIP model."""
import os
import re
import string
import warnings
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import AddedToken
if TYPE_CHECKING:
from transformers.tokenization_utils_base import TextInput
from transformers.utils import logging, requires_backends
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# File name of the SentencePiece model; `save_vocabulary` writes the vocabulary under this name.
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
# Marker character that `SiglipTokenizer.tokenize` prepends to the input text
# (occurrences already present in the text are replaced by a plain space first).
SPIECE_UNDERLINE = "▁"
class SiglipTokenizer(PreTrainedTokenizer):
    """
    Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"</s>"`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        model_max_length (`int`, *optional*, defaults to 64):
            The maximum length (in number of tokens) for model inputs.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    # Translation table deleting every character of `string.punctuation`; built
    # once here instead of on every `remove_punctuation` call.
    _punctuation_table = str.maketrans("", "", string.punctuation)

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="</s>",
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        model_max_length=64,
        do_lower_case=True,
        **kwargs,
    ) -> None:
        requires_backends(self, "protobuf")

        def _as_special_token(token):
            # Plain strings are wrapped as special AddedTokens; anything else
            # (e.g. an AddedToken supplied by the caller) is passed through.
            if isinstance(token, str):
                return AddedToken(
                    token, rstrip=True, lstrip=True, normalized=False, special=True
                )
            return token

        pad_token = _as_special_token(pad_token)
        unk_token = _as_special_token(unk_token)
        eos_token = _as_special_token(eos_token)

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.do_lower_case = do_lower_case
        # `vocab_file` must be set before `get_spm_processor`, which reads it.
        self.vocab_file = vocab_file
        self.sp_model = self.get_spm_processor()

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            model_max_length=model_max_length,
            do_lower_case=do_lower_case,
            **kwargs,
        )

    def get_spm_processor(self):
        """Load `self.vocab_file` into a SentencePiece processor with `add_dummy_prefix` turned off.

        The serialized model proto is read from disk, its normalizer spec is
        patched, and the patched proto is loaded into a processor created with
        `self.sp_model_kwargs`.
        """
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)

        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
            model_pb2 = import_protobuf()
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
            model.normalizer_spec.MergeFrom(normalizer_spec)
            sp_model = model.SerializeToString()
            tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
    def vocab_size(self):
        """Number of pieces in the SentencePiece model."""
        return self.sp_model.get_piece_size()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
    def get_vocab(self):
        """Return the token -> id mapping of the SentencePiece vocabulary, updated with the added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # normal case: one special token (marked 1) closes each sequence
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
            warnings.warn(
                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
                " eos tokens being added."
            )
            return token_ids
        else:
            return token_ids + [self.eos_token_id]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Token type
        ids are not distinguished here, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros, one per token plus one per trailing eos.
        """
        # One extra slot per sequence accounts for its trailing eos token.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + 1 + len(token_ids_1) + 1)

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A sequence has the following format:

        - single sequence: `X </s>`
        - pair of sequences: `A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
        if token_ids_1 is None:
            return token_ids_0
        else:
            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
            return token_ids_0 + token_ids_1

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
    def __getstate__(self):
        """Return the instance state with the SentencePiece processor blanked out."""
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        """Restore the instance state and rebuild the SentencePiece processor.

        The processor is rebuilt through `get_spm_processor` so that a restored
        (unpickled / deep-copied) tokenizer uses the same `add_dummy_prefix=False`
        normalizer patch as a freshly constructed one.
        """
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = self.get_spm_processor()

    def remove_punctuation(self, text: str) -> str:
        """Return `text` with every character of `string.punctuation` removed."""
        return text.translate(self._punctuation_table)

    # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
    def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
        """Returns canonicalized `text` (punctuation removed, whitespace collapsed and stripped).

        Args:
            text (`str`):
                String to be canonicalized.
            keep_punctuation_exact_string (`str`, *optional*):
                If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
                (but will still remove '{' and '}' that appear separately).
        """
        if keep_punctuation_exact_string:
            # Strip punctuation from the pieces between occurrences of the
            # protected string, then stitch them back together with it.
            text = keep_punctuation_exact_string.join(
                self.remove_punctuation(part)
                for part in text.split(keep_punctuation_exact_string)
            )
        else:
            text = self.remove_punctuation(text)
        text = re.sub(r"\s+", " ", text)
        text = text.strip()

        return text

    def tokenize(
        self, text: "TextInput", add_special_tokens=False, **kwargs
    ) -> List[str]:
        """
        Converts a string to a list of tokens.

        `SPIECE_UNDERLINE` characters in `text` are replaced by spaces and a single `SPIECE_UNDERLINE` is prepended
        before delegating to the parent `tokenize`. `add_special_tokens` is accepted but not used here.
        """
        tokens = super().tokenize(
            SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs
        )

        # Drop the leading marker when it is directly followed by a special token.
        if (
            len(tokens) > 1
            and tokens[0] == SPIECE_UNDERLINE
            and tokens[1] in self.all_special_tokens
        ):
            tokens = tokens[1:]
        return tokens

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
    def unk_token_length(self):
        """Number of pieces the SentencePiece model produces for the unknown token string."""
        return len(self.sp_model.encode(str(self.unk_token)))

    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE.

        For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.

        Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        text = self.canonicalize_text(text, keep_punctuation_exact_string=None)

        # 1. Encode string + prefix ex: "<unk> Hey" (a single encode pass; the
        #    text is not encoded separately beforehand).
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)

        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']. The prefix
        #    length is itself an encode call, so it is evaluated only once.
        prefix_length = self.unk_token_length
        if len(tokens) >= prefix_length:
            return tokens[prefix_length:]
        return tokens

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                # A separating space is inserted only before the first special
                # token of a run of consecutive special tokens.
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        # Flush whatever regular tokens remain after the last special token.
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the SentencePiece model file into `save_directory`.

        Args:
            save_directory (`str`):
                Existing directory to write into.
            filename_prefix (`str`, *optional*):
                When given, `"<prefix>-"` is prepended to the output file name.

        Returns:
            `Tuple[str]`: One-element tuple with the output path. If `save_directory` is not a directory an error is
            logged and `None` is returned instead.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        if os.path.isfile(self.vocab_file):
            # Copy the original file unless source and destination are the same path.
            if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
                copyfile(self.vocab_file, out_vocab_file)
        else:
            # The original file is gone: serialize the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)
decord==0.6.0
einops==0.8.1
huggingface_hub==0.29.1
matplotlib==3.7.0
numpy==1.24.4
opencv_python==4.7.0.72
pyarrow==11.0.0
PyYAML==6.0.2
Requests==2.32.3
safetensors==0.4.5
scipy==1.10.1
sentencepiece==0.1.99
torch==2.5.1
torchvision==0.20.1
transformers==4.49.0
#flash_attn==2.5.8
accelerate>=0.34.0
wandb
gradio
setuptools
wheel
ninja
bitsandbytes
xlsxwriter
triton ; sys_platform != 'win32'
triton-windows ; sys_platform == 'win32'
\ No newline at end of file
#!/bin/bash
set -x
# ============================================================
# SenseNova-SI-800K Training with BAGEL-7B-MoT
# ============================================================
#
# Step 1: Download the dataset from Hugging Face:
#   pip install huggingface_hub
#   python -c "
#   from huggingface_hub import snapshot_download
#   snapshot_download(
#       repo_id='sensenova/SenseNova-SI-800K',
#       repo_type='dataset',
#       local_dir='data/SenseNova-SI-800K',
#   )
#   "
#
# Step 2: Download the BAGEL-7B-MoT model:
#   python -c "
#   from huggingface_hub import snapshot_download
#   snapshot_download(
#       repo_id='ByteDance-Seed/BAGEL-7B-MoT',
#       local_dir='models/BAGEL-7B-MoT',
#   )
#   "
#   NOTE: the default MODEL_PATH below is
#   ${TRAINING_ROOT}/pretrained_models/BAGEL-7B-MoT; export MODEL_PATH if the
#   model was downloaded to a different location (such as the one above).
#
# Step 3: Run this script from the bagel-train directory.
#
# Environment overrides (all optional): MODEL_PATH, NPROC_PER_NODE, NNODES,
# RESULTS_DIR, DATASET_CONFIG_FILE, MASTER_ADDR, MASTER_PORT, RANK.
# ============================================================

# Directory containing this script, its parent (the repo), and the repo's
# parent (the training root). `pwd -P` yields the symlink-resolved path; the
# shell is used directly so paths containing quotes or backslashes are safe.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_DIR="$(cd "${SCRIPT_DIR}/.." && pwd -P)"
TRAINING_ROOT="$(cd "${REPO_DIR}/.." && pwd -P)"

MODEL_PATH="${MODEL_PATH:-${TRAINING_ROOT}/pretrained_models/BAGEL-7B-MoT}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
NNODES="${NNODES:-1}"
RESULTS_DIR="${RESULTS_DIR:-${TRAINING_ROOT}/results/bagel/sensenova_si_800K}"
DATASET_CONFIG_FILE="${DATASET_CONFIG_FILE:-${REPO_DIR}/data/configs/sensenova_si_800K.yaml}"

# Append the repo to PYTHONPATH without leaving an empty entry when
# PYTHONPATH was previously unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${REPO_DIR}"
export TRAINING_ROOT="${TRAINING_ROOT}"

# Abort rather than launch training from the wrong working directory.
cd "${REPO_DIR}" || exit 1

# NOTE(review): the value of RANK is passed as torchrun's --node-rank; confirm
# the launcher exports RANK as the node index.
# NOTE(review): --save_every 1 with --total_steps 2000 looks like a checkpoint
# on every step; confirm this is intended.
torchrun \
  --nproc_per_node "${NPROC_PER_NODE}" \
  --nnodes "${NNODES}" \
  --master-addr "${MASTER_ADDR:-127.0.0.1}" \
  --master-port "${MASTER_PORT:-29500}" \
  --node-rank "${RANK:-0}" \
  "${REPO_DIR}/train/pretrain_unified_navit.py" \
  --dataset_config_file "${DATASET_CONFIG_FILE}" \
  --model_name BAGEL \
  --resume_from_hf "${MODEL_PATH}" \
  --layer_module Qwen2MoTDecoderLayer \
  --max_latent_size 64 \
  --resume-from "${MODEL_PATH}" \
  --auto_resume True \
  --resume-model-only True \
  --finetune-from-ema True \
  --log_every 1 \
  --lr 1e-6 \
  --num_worker 0 \
  --prefetch_factor 1 \
  --expected_num_tokens 32768 \
  --max_num_tokens 70000 \
  --max_num_tokens_per_sample 50000 \
  --freeze_vit True \
  --visual_und True \
  --visual_gen False \
  --results_dir "${RESULTS_DIR}" \
  --checkpoint_dir "${RESULTS_DIR}/checkpoints" \
  --total_steps 2000 \
  --save_every 1 \
  --wandb_offline True \
  --num_shard "${NPROC_PER_NODE}"
#!/bin/bash
set -x
# ============================================================
# SenseNova-SI-8M Training with BAGEL-7B-MoT
# ============================================================
#
# Step 1: Download the dataset from Hugging Face:
# pip install huggingface_hub
# python -c "
# from huggingface_hub import snapshot_download
# snapshot_download(
# repo_id='sensenova/SenseNova-SI-8M',
# repo_type='dataset',
# local_dir='data/SenseNova-SI-8M',
# )
# "
#
# Step 2: Download the BAGEL-7B-MoT model:
# python -c "
# from huggingface_hub import snapshot_download
# snapshot_download(
# repo_id='ByteDance-Seed/BAGEL-7B-MoT',
# local_dir='models/BAGEL-7B-MoT',
# )
# "
#
# Step 3: Run this script from the bagel-train directory.
# ============================================================
# Resolve the script directory, the repository root (its parent), and the
# training root (the repository's parent). The path is handed to Python as an
# argument instead of being spliced into the code string, so directories whose
# names contain quotes or backslashes still resolve correctly.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_DIR="$(python -c 'import pathlib, sys; print(pathlib.Path(sys.argv[1]).parent.resolve())' "${SCRIPT_DIR}")"
TRAINING_ROOT="$(python -c 'import pathlib, sys; print(pathlib.Path(sys.argv[1]).parent.resolve())' "${REPO_DIR}")"

# Every setting below can be overridden from the environment.
MODEL_PATH="${MODEL_PATH:-${TRAINING_ROOT}/pretrained_models/BAGEL-7B-MoT}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
NNODES="${NNODES:-1}"
RESULTS_DIR="${RESULTS_DIR:-${TRAINING_ROOT}/results/bagel/sensenova_si_8M}"
DATASET_CONFIG_FILE="${DATASET_CONFIG_FILE:-${REPO_DIR}/data/configs/sensenova_si_8M.yaml}"

# Append the repository to PYTHONPATH without leaving an empty entry when
# PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${REPO_DIR}"
export TRAINING_ROOT="${TRAINING_ROOT}"

# Abort instead of launching from the wrong directory if the cd fails.
cd "${REPO_DIR}" || exit 1

# All expansions are quoted so that paths containing spaces reach the trainer
# as single arguments. MODEL_PATH is passed twice: once as the Hugging Face
# style model directory and once as the checkpoint to load weights from.
torchrun \
  --nproc_per_node "${NPROC_PER_NODE}" \
  --nnodes "${NNODES}" \
  --master-addr "${MASTER_ADDR:-127.0.0.1}" \
  --master-port "${MASTER_PORT:-29500}" \
  --node-rank "${RANK:-0}" \
  "${REPO_DIR}/train/pretrain_unified_navit.py" \
  --dataset_config_file "${DATASET_CONFIG_FILE}" \
  --model_name BAGEL \
  --resume_from_hf "${MODEL_PATH}" \
  --layer_module Qwen2MoTDecoderLayer \
  --max_latent_size 64 \
  --resume-from "${MODEL_PATH}" \
  --auto_resume True \
  --resume-model-only True \
  --finetune-from-ema True \
  --log_every 1 \
  --lr 1e-6 \
  --num_worker 0 \
  --prefetch_factor 1 \
  --expected_num_tokens 32768 \
  --max_num_tokens 70000 \
  --max_num_tokens_per_sample 50000 \
  --freeze_vit True \
  --visual_und True \
  --visual_gen False \
  --results_dir "${RESULTS_DIR}" \
  --checkpoint_dir "${RESULTS_DIR}/checkpoints" \
  --total_steps 50000 \
  --save_every 1 \
  --wandb_offline True \
  --num_shard "${NPROC_PER_NODE}"
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
from modeling.bagel.modeling_utils import (
MLPconnector,
PositionEmbedding,
TimestepEmbedder,
)
from modeling.bagel.qwen2_navit import (
Qwen2DecoderLayer,
Qwen2MoEDecoderLayer,
Qwen2MoTDecoderLayer,
)
from modeling.bagel.siglip_navit import SiglipEncoderLayer, SiglipVisionTransformer
from safetensors.torch import load_file, save_file
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
BackwardPrefetch,
CPUOffload,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
class FSDPConfig:
    """Plain holder for the FSDP settings used by the helpers in this module.

    Attributes:
        sharding_strategy: Name of a ``ShardingStrategy`` member, e.g.
            ``"FULL_SHARD"`` or ``"HYBRID_SHARD"``.
        backward_prefetch: Name of a ``BackwardPrefetch`` member.
        cpu_offload: Whether parameters are offloaded to CPU.
        num_replicate: Size of the "replicate" device-mesh dimension
            (only read for ``HYBRID_SHARD``).
        num_shard: Size of the "shard" device-mesh dimension
            (only read for ``HYBRID_SHARD``).
    """

    def __init__(
        self,
        sharding_strategy,
        backward_prefetch,
        cpu_offload,
        num_replicate,
        num_shard=8,
    ):
        # Strategy names are stored as strings; they are turned into enum
        # members only at wrap time.
        self.sharding_strategy, self.backward_prefetch = (
            sharding_strategy,
            backward_prefetch,
        )
        self.cpu_offload = cpu_offload
        # Device-mesh shape: (num_replicate, num_shard).
        self.num_replicate, self.num_shard = num_replicate, num_shard
def fsdp_wrapper(original_model, fsdp_config, ignored_modules=None):
    """Wrap ``original_model`` in FSDP according to ``fsdp_config``.

    Args:
        original_model: The module to shard.
        fsdp_config: An ``FSDPConfig`` carrying the sharding strategy, backward
            prefetch mode, CPU-offload flag and (for ``HYBRID_SHARD``) the
            device-mesh shape.
        ignored_modules: Modules FSDP should leave unmanaged. ``None`` (the
            default) means "no ignored modules".

    Returns:
        The FSDP-wrapped module, with parameters, gradient reduction and
        buffers all in bfloat16.
    """
    # Use a fresh list per call rather than one mutable default object shared
    # by every call.
    if ignored_modules is None:
        ignored_modules = []

    # Only HYBRID_SHARD needs an explicit 2-D (replicate x shard) mesh.
    if fsdp_config.sharding_strategy == "HYBRID_SHARD":
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        device_mesh = None

    # Each of these module classes becomes its own FSDP unit.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            Qwen2DecoderLayer,
            Qwen2MoEDecoderLayer,
            Qwen2MoTDecoderLayer,
            SiglipEncoderLayer,
            SiglipVisionTransformer,
            MLPconnector,
            TimestepEmbedder,
            PositionEmbedding,
        },
    )
    bf16_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    return FSDP(
        original_model,
        auto_wrap_policy=wrap_policy,
        ignored_modules=ignored_modules,
        mixed_precision=bf16_policy,
        # Map the global rank onto a device index of this node.
        device_id=dist.get_rank() % torch.cuda.device_count(),
        sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
        backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
        cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
        device_mesh=device_mesh,
    )
class FSDPCheckpoint:
    """Static helpers that write and restore FSDP training checkpoints.

    A checkpoint is a directory named after the zero-padded step. It holds
    ``model.safetensors``, optionally ``ema.safetensors``, one
    ``optimizer.<shard>-of-<total>.pt`` file per optimizer shard, and
    optionally ``scheduler.pt`` and ``data_status.pt``.
    """

    @staticmethod
    def _optimizer_shard_info(fsdp_config):
        """Return ``(shard_index, total_shards)`` naming this rank's optimizer file.

        Raises:
            NotImplementedError: For strategies other than ``FULL_SHARD`` and
                ``HYBRID_SHARD``.
        """
        if fsdp_config.sharding_strategy == "FULL_SHARD":
            return dist.get_rank(), dist.get_world_size()
        if fsdp_config.sharding_strategy == "HYBRID_SHARD":
            return dist.get_rank() % fsdp_config.num_shard, fsdp_config.num_shard
        raise NotImplementedError

    @staticmethod
    def fsdp_save_ckpt(
        ckpt_dir,
        train_steps,
        model,
        ema_model,
        optimizer,
        scheduler,
        data_status,
        logger,
        fsdp_config,
    ):
        """Write a checkpoint for step ``train_steps`` under ``ckpt_dir``.

        Every rank must call this: gathering the full state dicts and the
        final barrier are collective operations.

        Args:
            ckpt_dir: Root directory that holds the per-step checkpoint dirs.
            train_steps: Step number; the directory is named ``{step:07d}``.
            model: FSDP-wrapped model to save.
            ema_model: FSDP-wrapped EMA model, or ``None`` to skip it.
            optimizer: Optimizer whose sharded state is saved.
            scheduler: LR scheduler, or ``None`` to skip it.
            data_status: Data-loader progress object, or ``None`` to skip it.
            logger: Logger used for the progress message.
            fsdp_config: ``FSDPConfig`` describing the sharding layout.

        Raises:
            NotImplementedError: If the sharding strategy is neither
                ``FULL_SHARD`` nor ``HYBRID_SHARD``.
        """
        save_path = os.path.join(ckpt_dir, f"{train_steps:07d}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving checkpoint to {save_path}.")

        full_state_cfg = dict(rank0_only=True, offload_to_cpu=True)

        if ema_model is not None:
            torch.cuda.empty_cache()
            with FSDP.state_dict_type(
                ema_model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(**full_state_cfg),
            ):
                ema_state_dict = ema_model.state_dict()
                if dist.get_rank() == 0:
                    save_file(
                        ema_state_dict, os.path.join(save_path, "ema.safetensors")
                    )

        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(**full_state_cfg),
        ):
            model_state_dict = model.state_dict()
            if dist.get_rank() == 0:
                save_file(
                    model_state_dict, os.path.join(save_path, "model.safetensors")
                )

        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            shard_index, total_shards = FSDPCheckpoint._optimizer_shard_info(
                fsdp_config
            )
            optimizer_save_path = os.path.join(
                save_path, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            # FULL_SHARD: every rank writes its own shard. HYBRID_SHARD: only
            # ranks below num_shard write (presumably because the remaining
            # replica groups hold duplicate shards -- confirm).
            if (
                fsdp_config.sharding_strategy == "FULL_SHARD"
                or dist.get_rank() < fsdp_config.num_shard
            ):
                torch.save(optimizer.state_dict(), optimizer_save_path)

        if dist.get_rank() == 0 and scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))
        if dist.get_rank() == 0 and data_status is not None:
            torch.save(data_status, os.path.join(save_path, "data_status.pt"))

        # Keep all ranks in step until the checkpoint is fully written.
        dist.barrier()
        return

    @staticmethod
    def try_load_ckpt(
        resume_from, logger, model, ema_model=None, resume_from_ema=False
    ):
        """Load model (and EMA) weights from ``resume_from`` if it exists.

        Args:
            resume_from: Checkpoint directory, or ``None``.
            logger: Logger used for progress messages.
            model: Unwrapped model that receives the weights.
            ema_model: Optional EMA model that receives EMA weights.
            resume_from_ema: Initialise ``model`` from ``ema.safetensors``
                instead of ``model.safetensors``.

        Returns:
            The ``(model, ema_model)`` pair, unchanged when there is nothing
            to load.
        """
        if resume_from is None or not os.path.exists(resume_from):
            logger.info("Training from scratch.")
            return model, ema_model

        logger.info(f"Loading checkpoint from {resume_from}.")
        weights_name = "ema.safetensors" if resume_from_ema else "model.safetensors"
        model_state_dict_path = os.path.join(resume_from, weights_name)
        model_state_dict = load_file(model_state_dict_path, device="cpu")
        # NOTE position embeds are fixed sinusoidal embeddings, so we can just pop it off,
        # which makes it easier to adapt to different resolutions.
        model_state_dict.pop("vit_pos_embed.pos_embed", None)
        model_state_dict.pop("latent_pos_embed.pos_embed", None)
        msg = model.load_state_dict(model_state_dict, strict=False)
        logger.info(msg)
        del model_state_dict

        if ema_model is not None:
            ema_state_dict_path = os.path.join(resume_from, "ema.safetensors")
            if not os.path.exists(ema_state_dict_path):
                # No EMA weights saved: seed the EMA model from the same file
                # the main model was loaded from.
                logger.info(f"replicaing ema model from {model_state_dict_path}.")
                ema_state_dict_path = model_state_dict_path
            ema_state_dict = load_file(ema_state_dict_path, device="cpu")
            # NOTE position embeds are fixed sinusoidal embeddings, so we can just pop it off,
            # which makes it easier to adapt to different resolutions.
            ema_state_dict.pop("latent_pos_embed.pos_embed", None)
            ema_state_dict.pop("vit_pos_embed.pos_embed", None)
            msg = ema_model.load_state_dict(ema_state_dict, strict=False)
            logger.info(msg)
            del ema_state_dict

        return model, ema_model

    @staticmethod
    def try_load_train_state(resume_from, optimizer, scheduler, fsdp_config):
        """Restore optimizer, scheduler, step counter and data status.

        Args:
            resume_from: Checkpoint directory whose basename is the step
                number, or ``None``.
            optimizer: Optimizer that receives this rank's shard state.
            scheduler: LR scheduler to restore, or ``None``.
            fsdp_config: ``FSDPConfig`` describing the sharding layout.

        Returns:
            ``(optimizer, scheduler, train_steps, data_status)``.
            ``train_steps`` is the saved step plus one, or ``0`` when there is
            nothing to resume. ``data_status`` is this rank's entry of the
            saved list, or ``None`` when it is unavailable.

        Raises:
            NotImplementedError: If the sharding strategy is neither
                ``FULL_SHARD`` nor ``HYBRID_SHARD``.
        """
        if resume_from is None or not os.path.exists(resume_from):
            return optimizer, scheduler, 0, None

        shard_index, total_shards = FSDPCheckpoint._optimizer_shard_info(fsdp_config)
        optimizer_state_dict_path = os.path.join(
            resume_from, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
        )
        optimizer_state_dict = torch.load(
            optimizer_state_dict_path, map_location="cpu", weights_only=True
        )
        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict

        # fsdp_save_ckpt writes scheduler.pt only when a scheduler was given,
        # so mirror that condition here instead of failing on None.
        if scheduler is not None:
            scheduler_state_dict_path = os.path.join(resume_from, "scheduler.pt")
            scheduler_state_dict = torch.load(
                scheduler_state_dict_path, weights_only=True, map_location="cpu"
            )
            scheduler.load_state_dict(scheduler_state_dict)
            del scheduler_state_dict

        # The checkpoint directory is named after the last finished step.
        train_steps = int(os.path.basename(os.path.normpath(resume_from))) + 1

        # Layout of the saved data status (one list entry per rank):
        # data_status = [
        #     {
        #         dataset_name: {
        #             worker_id: [parquet_idx, row_group_id, row_idx],
        #         },
        #     },
        # ]
        data_status = None
        data_status_path = os.path.join(resume_from, "data_status.pt")
        if os.path.exists(data_status_path):
            all_status = torch.load(
                data_status_path, weights_only=True, map_location="cpu"
            )
            rank = dist.get_rank()
            if rank < len(all_status):
                data_status = all_status[rank]

        return optimizer, scheduler, train_steps, data_status
def grad_checkpoint_check_fn(module):
    """Return True when ``module`` is one of the layer types to checkpoint.

    Intended as a ``check_fn`` predicate for activation checkpointing.
    """
    checkpointed_layer_types = (
        Qwen2DecoderLayer,
        SiglipEncoderLayer,
        MLPconnector,
        Qwen2MoEDecoderLayer,
        Qwen2MoTDecoderLayer,
    )
    for layer_type in checkpointed_layer_types:
        if isinstance(module, layer_type):
            return True
    return False
def fsdp_ema_setup(ema_model, fsdp_config, ignored_modules=None):
    """Freeze ``ema_model`` and wrap it in FSDP.

    Args:
        ema_model: The EMA copy of the model.
        fsdp_config: ``FSDPConfig`` forwarded to ``fsdp_wrapper``.
        ignored_modules: Modules FSDP should leave unmanaged. ``None`` (the
            default) means "no ignored modules".

    Returns:
        The FSDP-wrapped EMA model, with every parameter set to
        ``requires_grad = False``.
    """
    # Use a fresh list per call rather than one mutable default object shared
    # by every call.
    if ignored_modules is None:
        ignored_modules = []
    # Gradients are switched off here; the EMA weights are updated in place elsewhere.
    for param in ema_model.parameters():
        param.requires_grad = False
    return fsdp_wrapper(ema_model, fsdp_config, ignored_modules=ignored_modules)
@torch.no_grad()
def fsdp_ema_update(ema_model, model, decay=0.9999):
    """Update the EMA weights in place: ``ema = decay * ema + (1 - decay) * new``.

    Works directly on the local FSDP flat parameters of both models, which are
    paired by position. Only handles whose training-side flat parameter
    requires grad are updated.

    Args:
        ema_model: FSDP-wrapped EMA model (updated in place).
        model: FSDP-wrapped training model.
        decay: EMA decay factor.
    """
    ema_handles = traversal_utils._get_fsdp_handles(ema_model)
    new_handles = traversal_utils._get_fsdp_handles(model)
    assert len(ema_handles) == len(
        new_handles
    ), "EMA and training models must be wrapped into the same FSDP units"

    ema_params = []
    new_params = []
    for ema_handle, new_handle in zip(ema_handles, new_handles):
        if ema_handle.flat_param is not None and new_handle.flat_param.requires_grad:
            ema_params.append(ema_handle.flat_param.data)
            # Match the EMA dtype so the foreach ops see uniform operands.
            new_params.append(
                new_handle.flat_param.data.to(dtype=ema_handle.flat_param.dtype)
            )

    # Nothing trainable: skip the foreach calls instead of handing them empty
    # tensor lists.
    if not ema_params:
        return

    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, new_params, alpha=1 - decay)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import gc
import os
from copy import deepcopy
from dataclasses import dataclass, field
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import wandb
import yaml
from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, PackedDataset, collate_wrapper
from modeling.autoencoder import load_ae
from modeling.bagel import (
Bagel,
BagelConfig,
Qwen2Config,
Qwen2ForCausalLM,
SiglipVisionConfig,
SiglipVisionModel,
)
from modeling.qwen2 import Qwen2Tokenizer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
get_constant_schedule_with_warmup,
get_cosine_with_min_lr_schedule_with_warmup,
)
from train.fsdp_utils import (
FSDPCheckpoint,
FSDPConfig,
fsdp_ema_setup,
fsdp_ema_update,
fsdp_wrapper,
grad_checkpoint_check_fn,
)
from train.train_utils import create_logger, get_latest_ckpt
@dataclass
class ModelArguments:
    """Command-line options describing the model architecture and its sources."""

    # --- language model ---
    model_name: str = field(
        default="BAGEL",
        metadata=dict(help="Name of the model."),
    )
    llm_path: str = field(
        default="hf/Qwen2.5-0.5B-Instruct/",
        metadata=dict(
            help="Path or HuggingFace repo ID of the pretrained Qwen2-style language model."
        ),
    )
    llm_qk_norm: bool = field(
        default=True,
        metadata=dict(
            help="Enable QK LayerNorm (qk_norm) inside the attention blocks."
        ),
    )
    tie_word_embeddings: bool = field(
        default=False,
        metadata=dict(
            help="Share input and output word embeddings (tied embeddings)."
        ),
    )
    layer_module: str = field(
        default="Qwen2MoTDecoderLayer",
        metadata=dict(
            help="Python class name of the decoder layer to instantiate."
        ),
    )

    # --- visual encoders ---
    vae_path: str = field(
        default="flux/vae/ae.safetensors",
        metadata=dict(
            help="Path to the pretrained VAE checkpoint for latent-space image generation."
        ),
    )
    vit_path: str = field(
        default="hf/siglip-so400m-14-980-flash-attn2-navit/",
        metadata=dict(
            help="Path or repo ID of the SigLIP Vision Transformer used for image understanding."
        ),
    )

    # --- patch / latent geometry ---
    max_latent_size: int = field(
        default=32,
        metadata=dict(
            help="Maximum latent grid size (patches per side) for the VAE latent tensor."
        ),
    )
    latent_patch_size: int = field(
        default=2,
        metadata=dict(
            help="Spatial size (in VAE pixels) covered by each latent patch."
        ),
    )
    vit_patch_size: int = field(
        default=14,
        metadata=dict(
            help="Patch size (pixels) for the Vision Transformer encoder."
        ),
    )
    vit_max_num_patch_per_side: int = field(
        default=70,
        metadata=dict(
            help="Maximum number of ViT patches along one image side after cropping / resize."
        ),
    )

    # --- connector and ViT feature selection ---
    connector_act: str = field(
        default="gelu_pytorch_tanh",
        metadata=dict(
            help="Activation function used in the latent-to-text connector MLP."
        ),
    )
    interpolate_pos: bool = field(
        default=False,
        metadata=dict(
            help="Interpolate positional embeddings when image resolution differs from pre-training."
        ),
    )
    vit_select_layer: int = field(
        default=-2,
        metadata=dict(
            help="Which hidden layer of the ViT to take as the visual feature (negative = from the end)."
        ),
    )
    vit_rope: bool = field(
        default=False,
        metadata=dict(help="Replace ViT positional encodings with RoPE."),
    )

    # --- conditioning dropout ---
    text_cond_dropout_prob: float = field(
        default=0.1,
        metadata=dict(
            help="Probability of dropping text embeddings during training."
        ),
    )
    vae_cond_dropout_prob: float = field(
        default=0.3,
        metadata=dict(
            help="Probability of dropping VAE latent inputs during training."
        ),
    )
    vit_cond_dropout_prob: float = field(
        default=0.3,
        metadata=dict(
            help="Probability of dropping ViT visual features during training."
        ),
    )
@dataclass
class DataArguments:
    """Command-line options for dataset selection, loading and sequence packing."""

    # --- dataset source and loader workers ---
    dataset_config_file: str = field(
        default="data/configs/example.yaml",
        metadata=dict(
            help="YAML file specifying dataset groups, weights, and preprocessing rules."
        ),
    )
    prefetch_factor: int = field(
        default=2,
        metadata=dict(
            help="How many batches each DataLoader worker pre-loads in advance."
        ),
    )
    num_workers: int = field(
        default=4,
        metadata=dict(
            help="Number of background workers for the PyTorch DataLoader."
        ),
    )

    # --- packing limits ---
    max_num_tokens_per_sample: int = field(
        default=16384,
        metadata=dict(
            help="Maximum tokens allowed in one raw sample; longer samples are skipped."
        ),
    )
    max_num_tokens: int = field(
        default=36864,
        metadata=dict(
            help="Hard limit on tokens in a packed batch; flush if adding a sample would exceed it."
        ),
    )
    prefer_buffer_before: int = field(
        default=16384,
        metadata=dict(
            help="While batch length is below this, pop from the overflow buffer before new sampling."
        ),
    )
    max_buffer_size: int = field(
        default=50,
        metadata=dict(
            help="Maximum number of oversized samples kept in the overflow buffer."
        ),
    )

    # --- reproducibility ---
    data_seed: int = field(
        default=42,
        metadata=dict(
            help="Seed used when shuffling / sampling data shards to ensure reproducibility."
        ),
    )
@dataclass
class TrainingArguments:
    """Command-line options controlling the training run.

    Covers modality switches, logging, resume behaviour, optimisation,
    FSDP layout and module freezing.
    """

    # --- modality switches ---
    visual_gen: bool = field(
        default=True, metadata={"help": "Train image generation branch."}
    )
    visual_und: bool = field(
        default=True, metadata={"help": "Train image understanding branch."}
    )
    # --- bookkeeping & logging ---
    results_dir: str = field(
        default="results", metadata={"help": "Root directory for logs."}
    )
    checkpoint_dir: str = field(
        default="results/checkpoints",
        metadata={"help": "Root directory for model checkpoints."},
    )
    wandb_project: str = field(
        default="bagel", metadata={"help": "Weights & Biases project name."}
    )
    wandb_name: str = field(
        default="run",
        metadata={"help": "Name shown in the Weights & Biases UI for this run."},
    )
    wandb_runid: str = field(
        default="0",
        metadata={
            "help": "Unique identifier to resume a previous W&B run, if desired."
        },
    )
    wandb_resume: str = field(
        default="allow",
        metadata={"help": "W&B resume mode: 'allow', 'must', or 'never'."},
    )
    wandb_offline: bool = field(
        default=False,
        metadata={"help": "Run W&B in offline mode (logs locally, sync later)."},
    )
    # --- reproducibility & resume ---
    global_seed: int = field(
        default=4396,
        metadata={"help": "Base random seed; actual seed is offset by rank for DDP."},
    )
    auto_resume: bool = field(
        default=False,
        metadata={
            "help": "Automatically pick up the latest checkpoint found in checkpoint_dir."
        },
    )
    # Optional: the default is None, i.e. no explicit checkpoint.
    resume_from: Optional[str] = field(
        default=None,
        metadata={
            "help": "Explicit checkpoint path to resume from; used when auto_resume is off or finds no checkpoint in checkpoint_dir."
        },
    )
    # Optional: the default is None, i.e. build the model from llm_path / vit_path / vae_path.
    resume_from_hf: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path of the pretrained BAGEL models, including llm, vit, and vae."
        },
    )
    resume_model_only: bool = field(
        default=False,
        metadata={
            "help": "Load only model weights, ignoring optimizer/scheduler states."
        },
    )
    finetune_from_ema: bool = field(
        default=False,
        metadata={
            "help": "When resume_model_only=True, load the EMA (exponential moving average) weights instead of raw weights."
        },
    )
    # --- reporting frequency ---
    log_every: int = field(
        default=10, metadata={"help": "Print / log every N training steps."}
    )
    save_every: int = field(
        default=2000, metadata={"help": "Save a checkpoint every N training steps."}
    )
    total_steps: int = field(
        default=500_000,
        metadata={"help": "Total number of optimizer steps to train for."},
    )
    # --- optimization & scheduler ---
    warmup_steps: int = field(
        default=2000,
        metadata={"help": "Linear warm-up steps before applying the main LR schedule."},
    )
    lr_scheduler: str = field(
        default="constant",
        metadata={"help": "Type of LR schedule: 'constant' or 'cosine'."},
    )
    lr: float = field(
        default=1e-4, metadata={"help": "Peak learning rate after warm-up."}
    )
    min_lr: float = field(
        default=1e-7,
        metadata={
            "help": "Minimum learning rate for cosine schedule (ignored for constant)."
        },
    )
    beta1: float = field(default=0.9, metadata={"help": "AdamW β₁ coefficient."})
    beta2: float = field(default=0.95, metadata={"help": "AdamW β₂ coefficient."})
    eps: float = field(
        default=1e-15, metadata={"help": "AdamW ε for numerical stability."}
    )
    ema: float = field(
        default=0.9999,
        metadata={
            "help": "Decay rate for the exponential moving average of model weights."
        },
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Gradient clipping threshold (L2 norm)."}
    )
    timestep_shift: float = field(
        default=1.0,
        metadata={
            "help": "Shift applied to diffusion timestep indices (for latent prediction)."
        },
    )
    mse_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the image-reconstruction MSE loss term."},
    )
    ce_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the language cross-entropy loss term."},
    )
    ce_loss_reweighting: bool = field(
        default=False,
        metadata={
            "help": "Reweight CE loss by token importance (provided via ce_loss_weights)."
        },
    )
    expected_num_tokens: int = field(
        default=32768,
        metadata={
            "help": "Soft target token count; yield the batch once it reaches or exceeds this size."
        },
    )
    # --- distributed training / FSDP ---
    # num_replicate x num_shard is the device-mesh shape built for HYBRID_SHARD.
    num_replicate: int = field(
        default=1,
        metadata={
            "help": "Size of the replicate dimension of the device mesh when using FSDP HYBRID_SHARD."
        },
    )
    num_shard: int = field(
        default=8,
        metadata={"help": "Number of parameter shards when using FSDP HYBRID_SHARD."},
    )
    sharding_strategy: str = field(
        default="HYBRID_SHARD",
        metadata={
            "help": "FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, etc."
        },
    )
    backward_prefetch: str = field(
        default="BACKWARD_PRE",
        metadata={
            "help": "FSDP backward prefetch strategy (BACKWARD_PRE or NO_PREFETCH)."
        },
    )
    cpu_offload: bool = field(
        default=False, metadata={"help": "Enable FSDP parameter offload to CPU."}
    )
    # --- module freezing ---
    freeze_llm: bool = field(
        default=False,
        metadata={"help": "Keep language-model weights fixed (no gradient updates)."},
    )
    freeze_vit: bool = field(
        default=False, metadata={"help": "Keep ViT weights fixed during training."}
    )
    freeze_vae: bool = field(
        default=True,
        metadata={
            "help": "Keep VAE weights fixed; only predict latents, don't fine-tune encoder/decoder."
        },
    )
    freeze_und: bool = field(
        default=False,
        metadata={"help": "Freeze the visual understanding connector layers."},
    )
    copy_init_moe: bool = field(
        default=True,
        metadata={
            "help": "Duplicate initial MoE experts so each has identical initialisation."
        },
    )
    use_flex: bool = field(
        default=False,
        metadata={
            "help": "Enable FLEX (flash-ext friendly) packing algorithm for sequence data."
        },
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of steps to accumulate gradients before performing optimization"
        },
    )
def main():
    """Run distributed BAGEL training under FSDP.

    Steps: initialise NCCL and logging, build the language / ViT / VAE models,
    optionally load weights and training state from a checkpoint, wrap the
    model and its EMA copy in FSDP, then iterate over the packed data loader
    computing the CE and/or MSE loss, stepping the optimizer and periodically
    logging and checkpointing.
    """
    assert torch.cuda.is_available()
    dist.init_process_group("nccl")
    device = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device)
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging: only rank 0 writes log files and talks to W&B.
    if dist.get_rank() == 0:
        os.makedirs(training_args.results_dir, exist_ok=True)
        os.makedirs(training_args.checkpoint_dir, exist_ok=True)
        logger = create_logger(training_args.results_dir, dist.get_rank())
        wandb.init(
            project=training_args.wandb_project,
            id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
            name=training_args.wandb_name,
            resume=training_args.wandb_resume,
            mode="offline" if training_args.wandb_offline else "online",
        )
        wandb.config.update(training_args)
        wandb.config.update(model_args)
        wandb.config.update(data_args)
    else:
        logger = create_logger(None, dist.get_rank())
    dist.barrier()
    logger.info(f"Training arguments {training_args}")
    logger.info(f"Model arguments {model_args}")
    logger.info(f"Data arguments {data_args}")

    # prepare auto resume logic:
    # A checkpoint found by auto_resume is always a full resume (model +
    # optimizer state). Otherwise fall back to the explicit resume_from path
    # and honour resume_model_only / finetune_from_ema.
    resume_from = None
    if training_args.auto_resume:
        resume_from = get_latest_ckpt(training_args.checkpoint_dir)
    if resume_from is not None:
        resume_model_only = False
        finetune_from_ema = False
    else:
        resume_from = training_args.resume_from
        resume_model_only = training_args.resume_model_only
        finetune_from_ema = resume_model_only and training_args.finetune_from_ema

    # Set seed: distinct per rank.
    seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
    set_seed(seed)

    # Setup model:
    from_hf = training_args.resume_from_hf is not None
    if from_hf:
        llm_config = Qwen2Config.from_json_file(
            os.path.join(training_args.resume_from_hf, "llm_config.json")
        )
    else:
        llm_config = Qwen2Config.from_pretrained(model_args.llm_path)
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    if from_hf:
        # Weights are loaded later from the checkpoint, so only build the architecture.
        language_model = Qwen2ForCausalLM(llm_config)
    else:
        language_model = Qwen2ForCausalLM.from_pretrained(
            model_args.llm_path, config=llm_config
        )
    if training_args.copy_init_moe:
        language_model.init_moe()

    if training_args.visual_und:
        if from_hf:
            vit_config = SiglipVisionConfig.from_json_file(
                os.path.join(training_args.resume_from_hf, "vit_config.json")
            )
        else:
            vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
        # Keep only the layers up to the selected one (vit_select_layer is
        # counted from the end when negative).
        vit_config.num_hidden_layers = (
            vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
        )
        vit_config.rope = model_args.vit_rope
        if from_hf:
            vit_model = SiglipVisionModel(vit_config)
        else:
            vit_model = SiglipVisionModel.from_pretrained(
                model_args.vit_path, config=vit_config
            )

    if training_args.visual_gen:
        vae_model, vae_config = load_ae(
            local_path=os.path.join(training_args.resume_from_hf, "ae.safetensors")
            if from_hf
            else model_args.vae_path
        )

    config = BagelConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model = Bagel(
        language_model, vit_model if training_args.visual_und else None, config
    )
    if training_args.visual_und:
        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

    # Setup tokenizer for model:
    tokenizer = Qwen2Tokenizer.from_pretrained(
        training_args.resume_from_hf
        if training_args.resume_from_hf
        else model_args.llm_path
    )
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)

    # maybe freeze something:
    if training_args.freeze_vae and training_args.visual_gen:
        for param in vae_model.parameters():
            param.requires_grad = False
    if training_args.freeze_llm:
        model.language_model.eval()
        for param in model.language_model.parameters():
            param.requires_grad = False
    if training_args.freeze_vit and training_args.visual_und:
        model.vit_model.eval()
        for param in model.vit_model.parameters():
            param.requires_grad = False

    # Setup FSDP and load pretrained model:
    fsdp_config = FSDPConfig(
        sharding_strategy=training_args.sharding_strategy,
        backward_prefetch=training_args.backward_prefetch,
        cpu_offload=training_args.cpu_offload,
        num_replicate=training_args.num_replicate,
        num_shard=training_args.num_shard,
    )
    # Weights are loaded before wrapping, so both copies are plain modules here.
    ema_model = deepcopy(model)
    model, ema_model = FSDPCheckpoint.try_load_ckpt(
        resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
    )
    ema_model = fsdp_ema_setup(ema_model, fsdp_config)
    fsdp_model = fsdp_wrapper(model, fsdp_config)
    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=grad_checkpoint_check_fn,
    )
    if dist.get_rank() == 0:
        print(fsdp_model)
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

    # Setup optimizer and scheduler
    optimizer = torch.optim.AdamW(
        fsdp_model.parameters(),
        lr=training_args.lr,
        betas=(training_args.beta1, training_args.beta2),
        eps=training_args.eps,
        weight_decay=0,
    )
    if training_args.lr_scheduler == "cosine":
        scheduler = get_cosine_with_min_lr_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.total_steps,
            min_lr=training_args.min_lr,
        )
    elif training_args.lr_scheduler == "constant":
        scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=training_args.warmup_steps
        )
    else:
        raise ValueError(
            f"Unsupported lr_scheduler {training_args.lr_scheduler!r}; "
            "expected 'constant' or 'cosine'."
        )

    # maybe resume optimizer, scheduler, and train_steps
    if resume_model_only:
        train_step = 0
        data_status = None
    else:
        optimizer, scheduler, train_step, data_status = (
            FSDPCheckpoint.try_load_train_state(
                resume_from,
                optimizer,
                scheduler,
                fsdp_config,
            )
        )

    # Setup packed dataloader
    with open(data_args.dataset_config_file, "r") as stream:
        dataset_meta = yaml.safe_load(stream)
    dataset_config = DataConfig(grouped_datasets=dataset_meta)
    if training_args.visual_und:
        dataset_config.vit_patch_size = model_args.vit_patch_size
        dataset_config.max_num_patch_per_side = model_args.vit_max_num_patch_per_side
    if training_args.visual_gen:
        vae_image_downsample = model_args.latent_patch_size * vae_config.downsample
        dataset_config.vae_image_downsample = vae_image_downsample
        dataset_config.max_latent_size = model_args.max_latent_size
        dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
        dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
        dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
    collate_fn = collate_wrapper()
    train_dataset = PackedDataset(
        dataset_config,
        tokenizer=tokenizer,
        special_tokens=new_token_ids,
        local_rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_workers=data_args.num_workers,
        expected_num_tokens=training_args.expected_num_tokens,
        max_num_tokens_per_sample=data_args.max_num_tokens_per_sample,
        max_num_tokens=data_args.max_num_tokens,
        max_buffer_size=data_args.max_buffer_size,
        prefer_buffer_before=data_args.prefer_buffer_before,
        interpolate_pos=model_args.interpolate_pos,
        use_flex=training_args.use_flex,
        data_status=data_status,
    )
    train_dataset.set_epoch(data_args.data_seed)
    loader_kwargs = dict(
        batch_size=1,  # batch size is 1 packed dataset
        num_workers=data_args.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
        drop_last=True,
    )
    if data_args.num_workers > 0:
        # prefetch_factor is only passed when worker processes are used.
        loader_kwargs["prefetch_factor"] = data_args.prefetch_factor
    train_loader = DataLoader(train_dataset, **loader_kwargs)

    # Prepare models for training:
    if training_args.visual_gen:
        vae_model.to(device).eval()
    fsdp_model.train()
    ema_model.eval()

    # Guard against a non-positive value, which would break the modulo below.
    accum_steps = max(1, training_args.gradient_accumulation_steps)
    # Defined up front so logging works even before the first optimizer step
    # (e.g. at step 0 of an accumulation window); holds the last clipped norm.
    total_norm = torch.zeros((), device=device)

    # train loop
    # NOTE(review): the loop has no explicit stop at total_steps; it ends only
    # when the data loader is exhausted -- confirm this is intended.
    start_time = time()
    logger.info(
        f"Training for {training_args.total_steps} steps, starting at {train_step}..."
    )
    for curr_step, data in enumerate(train_loader, start=train_step):
        data = data.cuda(device).to_dict()
        data_indexes = data.pop("batch_data_indexes", None)
        ce_loss_weights = data.pop("ce_loss_weights", None)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            if training_args.visual_gen:
                with torch.no_grad():
                    data["padded_latent"] = vae_model.encode(data.pop("padded_images"))
            loss_dict = fsdp_model(**data)

        loss = 0
        ce = loss_dict["ce"]
        if ce is not None:
            # Normalise by the global token (or weight) count so each rank's
            # loss is its share of the mean over all ranks.
            total_ce_tokens = torch.tensor(len(data["ce_loss_indexes"]), device=device)
            dist.all_reduce(total_ce_tokens, op=dist.ReduceOp.SUM)
            if training_args.ce_loss_reweighting:
                ce = ce * ce_loss_weights
                total_ce_loss_weights = ce_loss_weights.sum()
                dist.all_reduce(total_ce_loss_weights, op=dist.ReduceOp.SUM)
                ce = ce.sum() * dist.get_world_size() / total_ce_loss_weights
            else:
                ce = ce.sum() * dist.get_world_size() / total_ce_tokens
            loss_dict["ce"] = ce.detach()
            loss = loss + ce * training_args.ce_weight
        else:
            assert not training_args.visual_und
            loss_dict["ce"] = torch.tensor(0, device=device)
            total_ce_tokens = torch.tensor(0, device=device)

        if training_args.visual_gen:
            mse = loss_dict["mse"]
            total_mse_tokens = torch.tensor(
                len(data["mse_loss_indexes"]), device=device
            )
            dist.all_reduce(total_mse_tokens, op=dist.ReduceOp.SUM)
            mse = mse.mean(dim=-1).sum() * dist.get_world_size() / total_mse_tokens
            loss_dict["mse"] = mse.detach()
            loss = loss + mse * training_args.mse_weight
        else:
            loss_dict["mse"] = torch.tensor(0, device=device)
            total_mse_tokens = torch.tensor(0, device=device)

        # Move the input tensors to the CPU and release cached blocks before
        # the backward pass to lower peak GPU memory.
        for k in data.keys():
            if isinstance(data[k], torch.Tensor):
                data[k] = data[k].cpu()
        torch.cuda.empty_cache()

        # Scale so the accumulated gradient is the mean over the micro-batches
        # of one accumulation window (a no-op when accum_steps == 1).
        (loss / accum_steps).backward()

        # Step once per full window. Counting from curr_step + 1 makes the
        # first window the same length as the others and, with
        # accum_steps == 1, steps on every iteration including step 0.
        if (curr_step + 1) % accum_steps == 0:
            total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)

        # Log loss values:
        if curr_step % training_args.log_every == 0:
            total_samples = torch.tensor(len(data["sample_lens"]), device=device)
            dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)

            # Measure training speed:
            torch.cuda.synchronize()
            end_time = time()
            steps_per_sec = training_args.log_every / (end_time - start_time)
            message = f"(step={curr_step:07d}) "
            wandb_log = {}
            for key, value in loss_dict.items():
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(value.item(), device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                message += f"Train Loss {key}: {avg_loss:.4f}, "
                wandb_log[key] = avg_loss
            message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
            logger.info(message)

            wandb_log["lr"] = optimizer.param_groups[0]["lr"]
            wandb_log["total_mse_tokens"] = total_mse_tokens.item()
            wandb_log["total_ce_tokens"] = total_ce_tokens.item()
            wandb_log["total_norm"] = total_norm.item()
            wandb_log["total_samples"] = total_samples.item()

            # Peak memory (MiB), maximum across ranks; logged as plain floats.
            mem_allocated = torch.tensor(
                torch.cuda.max_memory_allocated() / 1024**2, device=device
            )
            dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
            wandb_log["mem_allocated"] = mem_allocated.item()
            mem_cache = torch.tensor(
                torch.cuda.max_memory_reserved() / 1024**2, device=device
            )
            dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
            wandb_log["mem_cache"] = mem_cache.item()

            if dist.get_rank() == 0:
                wandb.log(wandb_log, step=curr_step)
            start_time = time()

        # Track, per dataset and worker, the latest data position consumed.
        if data_status is None:
            data_status = {}
        for item in data_indexes:
            if item["dataset_name"] not in data_status.keys():
                data_status[item["dataset_name"]] = {}
            data_status[item["dataset_name"]][item["worker_id"]] = item["data_indexes"]

        if curr_step > 0 and curr_step % training_args.save_every == 0:
            # The EMA weights and the data status are not written here
            # (both are passed as None), so checkpoints hold the raw model,
            # optimizer shards and scheduler only.
            FSDPCheckpoint.fsdp_save_ckpt(
                ckpt_dir=training_args.checkpoint_dir,
                train_steps=curr_step,
                model=fsdp_model,
                ema_model=None,
                optimizer=optimizer,
                scheduler=scheduler,
                logger=logger,
                fsdp_config=fsdp_config,
                data_status=None,
            )
        torch.cuda.empty_cache()

    logger.info("Done!")
    if dist.get_rank() == 0:
        wandb.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
def create_logger(logging_dir, rank, filename="log"):
    """
    Build the process logger.

    Rank 0 with a ``logging_dir`` configures root logging to emit INFO-level
    records to both stdout/stderr and ``<logging_dir>/<filename>.txt``.
    Every other case returns the module logger with a ``NullHandler`` attached.
    """
    wants_real_logger = rank == 0 and logging_dir is not None
    if wants_real_logger:
        console_handler = logging.StreamHandler()
        file_handler = logging.FileHandler(f"{logging_dir}/{filename}.txt")
        logging.basicConfig(
            level=logging.INFO,
            # Timestamp is wrapped in ANSI blue.
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[console_handler, file_handler],
        )
        return logging.getLogger(__name__)

    # Placeholder logger for the other ranks / no log directory.
    quiet_logger = logging.getLogger(__name__)
    quiet_logger.addHandler(logging.NullHandler())
    return quiet_logger
def get_latest_ckpt(checkpoint_dir):
    """Return the path of the highest-numbered checkpoint directory.

    Only sub-directories whose name parses as an integer step are considered;
    other entries (files, folders such as ``wandb`` or ``tmp``) are ignored
    instead of raising ``ValueError``.

    Args:
        checkpoint_dir: Directory that holds one sub-directory per saved step.

    Returns:
        ``os.path.join(checkpoint_dir, <latest step dir>)``, or ``None`` when
        ``checkpoint_dir`` does not exist or holds no step directory.
    """
    if not os.path.isdir(checkpoint_dir):
        return None

    latest_step = None
    latest_name = None
    for name in os.listdir(checkpoint_dir):
        if not os.path.isdir(os.path.join(checkpoint_dir, name)):
            continue
        try:
            step = int(name)
        except ValueError:
            # Not a step directory.
            continue
        # ">=" keeps the last of equal-valued names, as a stable sort would.
        if latest_step is None or step >= latest_step:
            latest_step, latest_name = step, name

    if latest_name is None:
        return None
    return os.path.join(checkpoint_dir, latest_name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment