Commit 441126fe authored by luopl

"Initial commit"

icon.png (50.3 KB)

# Unique model identifier
modelCode=1855
# Model name
modelName=Ovis2.5_pytorch
# Model description
modelDescription=Ovis2.5 is designed for native-resolution visual perception and enhanced multimodal reasoning. It delivers leading performance on image reasoning, video understanding, and grounding benchmarks, demonstrating strong general-purpose multimodal capabilities.
# Process type
processType=推理
# Algorithm category
appCategory=多模态
# Framework type
frameType=pytorch
# Accelerator type
accelerateType=BW1000
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# for torch==2.4.0
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint", lineno=1399)
from transformers import AutoConfig, AutoModel
from .vit.modeling_siglip2_navit import Siglip2NavitModel
from .vit.configuration_siglip2_navit import Siglip2NavitConfig
AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
AutoModel.register(Siglip2NavitConfig, Siglip2NavitModel)
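A brief, hedged sketch of what the two register calls above enable (it assumes the `ovis` package and `flash_attn` are installed, and that importing `ovis.model` has executed the registrations): the custom ViT becomes reachable through the transformers Auto classes like any built-in architecture. The `num_hidden_layers` override is illustrative only, not a shipped setting.

import ovis.model  # noqa: F401  -- importing the package runs the register calls above
from transformers import AutoConfig, AutoModel

vit_config = AutoConfig.for_model("siglip2_navit", num_hidden_layers=2)  # tiny, illustrative config
vit_model = AutoModel.from_config(vit_config)
print(type(vit_model).__name__)  # Siglip2NavitModel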
from typing import Union, Optional
from transformers import PretrainedConfig, Qwen3Config
from . import Siglip2NavitConfig
class OvisConfig(PretrainedConfig):
model_type = "ovis"
sub_configs = dict(llm_config=Qwen3Config, vit_config=Siglip2NavitConfig)
def __init__(self,
llm_config: Optional[Union[Qwen3Config, dict]] = None,
vit_config: Optional[Union[Siglip2NavitConfig, dict]] = None,
visual_vocab_size=65536,
hidden_size=None,
conversation_formatter_class=None,
**kwargs
):
super().__init__(**kwargs)
if isinstance(llm_config, dict):
llm_config = Qwen3Config(**llm_config)
self.llm_config = llm_config
if isinstance(vit_config, dict):
vit_config = Siglip2NavitConfig(**vit_config)
self.vit_config = vit_config
self.visual_vocab_size = visual_vocab_size
self.hidden_size = hidden_size
self.conversation_formatter_class = conversation_formatter_class
if kwargs.get('attn_implementation'):
self.llm_config._attn_implementation = kwargs['attn_implementation']
self.vit_config._attn_implementation = kwargs['attn_implementation']
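A small sketch (illustrative sizes, not the shipped configuration; assumes the `ovis` package and its `flash_attn` dependency are importable) of how `OvisConfig` promotes plain dicts to the typed sub-configs and pushes `attn_implementation` down to both of them:

from transformers import Qwen3Config
from ovis.model.configuration_ovis import OvisConfig

config = OvisConfig(
    llm_config=dict(hidden_size=1024, num_hidden_layers=2, num_attention_heads=8),
    vit_config=dict(hidden_size=512, num_hidden_layers=2, num_attention_heads=8),
    hidden_size=1024,
    visual_vocab_size=65536,
    attn_implementation="flash_attention_2",
)
assert isinstance(config.llm_config, Qwen3Config)               # dict was promoted to Qwen3Config
assert config.vit_config._attn_implementation == "flash_attention_2"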
import copy
from abc import ABC, abstractmethod
from typing import List, Dict
from ovis.util.constants import IMAGE_TOKEN_ID, IGNORE_ID, IMAGE_TOKEN, VIDEO_TOKEN_ID, VIDEO_TOKEN
class ConversationFormatter(ABC):
support_tokenizer_types = None
def __init__(self, tokenizer):
tokenizer_type = type(tokenizer).__name__
assert tokenizer_type in self.support_tokenizer_types, \
f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
self.tokenizer = tokenizer
self.image_token = IMAGE_TOKEN
self.image_token_id = IMAGE_TOKEN_ID
self.ignore_id = IGNORE_ID
self.im_end = None
self.video_token = VIDEO_TOKEN
self.video_token_id = VIDEO_TOKEN_ID
def _tokenize_with_image_symbol(self, text):
if text.find(self.video_token) != -1:
token = self.video_token
token_id = self.video_token_id
else:
token = self.image_token
token_id = self.image_token_id
text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
text.split(token)]
token_ids = []
num_chunks = len(text_chunks)
for i, chunk in enumerate(text_chunks):
token_ids.extend(chunk)
if i < num_chunks - 1:
token_ids.append(token_id)
return token_ids
@abstractmethod
def format(self, conversations: List[Dict], generation_preface=None, enable_thinking=False):
pass
@abstractmethod
def format_query(self, query, generation_preface=""):
pass
class Qwen3ConversationFormatter(ConversationFormatter):
support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.from2role = {
"system": "<|im_start|>system\n",
"human": "<|im_start|>user\n",
"gpt": "<|im_start|>assistant\n",
"ignored_gpt": "<|im_start|>assistant\n",
}
self.im_end = "<|im_end|>\n"
self.empty_think = "<think>\n\n</think>\n\n"
self.gpt_token_nums = None
def _initialize_gpt_token_nums(self) -> Dict[str, int]:
think_prefix = self.from2role["gpt"]
think_num = len(
self.tokenizer(think_prefix, add_special_tokens=False).input_ids
)
no_think_prefix = self.from2role["gpt"] + self.empty_think
no_think_num = len(
self.tokenizer(no_think_prefix, add_special_tokens=False).input_ids
)
return {'think': think_num, 'no_think': no_think_num}
# enable_thinking is deprecated
def format(self, conversations: List[Dict], generation_preface=None, enable_thinking=False):
conversations = copy.deepcopy(conversations)
if generation_preface is not None:
conversations.append({
"from": "gpt",
"value": generation_preface
})
prompt = ""
input_ids = []
labels = []
num_conversation = len(conversations)
for i, conversation in enumerate(conversations):
frm = conversation["from"]
role = self.from2role[frm]
message = conversation["value"]
has_thinking = '<think>' in message and '</think>' in message
if frm == 'gpt' and not has_thinking and generation_preface is None:
text = role + self.empty_think + message
else:
text = role + message
if self.gpt_token_nums is None:
self.gpt_token_nums = self._initialize_gpt_token_nums()
gpt_token_num = self.gpt_token_nums['think'] if has_thinking else self.gpt_token_nums['no_think']
if i < num_conversation - 1 or generation_preface is None:
text += self.im_end
prompt += text
token_ids = self._tokenize_with_image_symbol(text)
input_ids.extend(token_ids)
label_ids = [self.ignore_id] * len(token_ids)
if frm == "gpt" and generation_preface is None:
# learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
label_ids[gpt_token_num:-1] = token_ids[gpt_token_num:-1]
labels.extend(label_ids)
assert self._tokenize_with_image_symbol(prompt) == input_ids
assert len(input_ids) == len(labels)
if conversations[-1]['from'] == "gpt" and generation_preface is None:
# remove the last `\n` following `im_end` in input_ids
input_ids.pop()
labels.pop()
return prompt, input_ids, labels
def format_query(self, query, generation_preface="", enable_thinking=False):
prompt, input_ids, _ = self.format([{
"from": "human",
"value": query
}], generation_preface=generation_preface, enable_thinking=enable_thinking)
return prompt, input_ids
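A hedged usage sketch of the formatter above; the checkpoint path is hypothetical, and any Qwen2-style fast tokenizer should work since the constructor only checks the tokenizer's class name:

from transformers import AutoTokenizer
from ovis.model.conversation_formatter import Qwen3ConversationFormatter

tokenizer = AutoTokenizer.from_pretrained("path/to/Ovis2.5-checkpoint")  # hypothetical path
formatter = Qwen3ConversationFormatter(tokenizer)
prompt, input_ids = formatter.format_query("<image>\nWhat is in this picture?")
# prompt == "<|im_start|>user\n<image>\nWhat is in this picture?<|im_end|>\n<|im_start|>assistant\n"
# input_ids contains IMAGE_TOKEN_ID at the position of the <image> placeholder.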
import logging
import math
import os
from datetime import datetime
from importlib import import_module
from typing import List, Union, Callable, Optional, Dict, Tuple
import PIL.Image
import deepspeed
import numpy as np
import torch
from torch import Tensor
from torch.nn import init
from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoImageProcessor
from transformers.generation.utils import GenerateOutput
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config
from ovis.model.configuration_ovis import OvisConfig
from ovis.model.conversation_formatter import ConversationFormatter
from ovis.util.constants import IGNORE_ID, BEGIN_LINE, END_LINE, VISUAL_ATOM_ID, INDICATOR_IDS, \
IMAGE_TOKEN_ID, VIDEO_TOKEN_ID
from ovis.util.utils import rank0_print
class VisualEmbedding(torch.nn.Embedding):
def forward(self, visual_tokens: Tensor) -> Tensor:
if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
return super().forward(visual_tokens)
return torch.matmul(visual_tokens, self.weight)
def reset_parameters(self, mean=0., std=1.) -> None:
init.normal_(self.weight, mean=mean, std=std)
self._fill_padding_idx_with_zero()
class VisualTokenizer(torch.nn.Module):
def __init__(self, vit, visual_vocab_size, image_processor_name_or_path, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vit = vit
self.image_processor = AutoImageProcessor.from_pretrained(image_processor_name_or_path, do_center_crop=False)
head_dim = visual_vocab_size - len(INDICATOR_IDS)
self.head = torch.nn.Sequential(
torch.nn.Linear(self.vit.config.hidden_size * self.vit.config.hidden_stride ** 2, head_dim, bias=False),
torch.nn.LayerNorm(head_dim)
)
def _get_last_block(self):
return self.vit._get_block(-1)
def _encode(self, pixel_values, grid_thws):
output = self.vit(pixel_values, grid_thws, output_hidden_states=True, return_dict=True)
features = output.hidden_states[-1]
seq_len, _ = features.shape
features = features.reshape(seq_len // (self.vit.config.hidden_stride ** 2), -1)
return features
# Adapted from qwen2_vl
@staticmethod
def smart_resize(
height: int, width: int, factor: int = 28, min_pixels: int = 448 * 448, max_pixels: int = 1344 * 1792
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if height < factor or width < factor:
logging.warning(
f"Resizing image from ({height=}, {width=}) because a dimension is smaller than {factor}."
)
if height < width:
width = round(factor / height * width)
height = factor
else:
height = round(factor / width * height)
width = factor
elif max(height, width) / min(height, width) > 200:
logging.warning(
f"Resizing image from ({height=}, {width=}) because the aspect ratio is larger than 200"
)
if height > width:
height = 200 * width
else:
width = 200 * height
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
def preprocess(
self,
image: Optional[PIL.Image.Image] = None,
video: Optional[List[PIL.Image.Image]] = None,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None
):
patch_size = self.vit.config.patch_size
temporal_patch_size = self.vit.config.temporal_patch_size
hidden_stride = self.vit.config.hidden_stride
assert (image is None) ^ (video is None), "Invalid input: expect either image or video"
if image is not None:
images = [image]
else:
images = video
images = [image.convert("RGB") if image.mode != 'RGB' else image for image in images]
width, height = images[0].size
processed_images = []
for image in images:
resized_height, resized_width = self.smart_resize(
height,
width,
factor=patch_size * hidden_stride,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
new_size = dict(height=resized_height, width=resized_width)
new_image = self.image_processor.preprocess(image, size=new_size, return_tensors="np")['pixel_values'][0]
processed_images.append(new_image)
patches = np.array(processed_images)
if patches.shape[0] % temporal_patch_size != 0:
repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - 1, axis=0)
patches = np.concatenate([patches, repeats], axis=0)
channel = patches.shape[1]
grid_t = patches.shape[0] // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
patches = patches.reshape(
grid_t,
temporal_patch_size,
channel,
grid_h // hidden_stride,
hidden_stride,
patch_size,
grid_w // hidden_stride,
hidden_stride,
patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
)
flatten_patches = torch.tensor(flatten_patches)
return flatten_patches, grid_thw
def get_dummy_visual_inputs(self):
pixel_values = torch.zeros((2 * 2, 3 * self.vit.config.patch_size ** 2), dtype=self.vit.dtype,
device=self.vit.device)
grid_thws = torch.tensor([[1, 2, 2]], dtype=torch.long, device=self.vit.device)
return pixel_values, grid_thws
def forward(
self, pixel_values, grid_thws
) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
features = self._encode(pixel_values, grid_thws)
logits = self.head(features)
tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
# tokens' shape is [#Token, VocabSize - len(INDICATOR_IDS)], so pad with
# [#Token, len(INDICATOR_IDS)] zeros so that tokens' shape becomes [#Token, VocabSize];
token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(token_len, len(INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=1)
return tokens
def get_monitor_tensors(self):
monitor_tensors = dict(
vit_bottom=self.vit._get_attn_weight(0),
vit_top=self.vit._get_attn_weight(-1),
head=self.head[0].weight,
pos_embed=self.vit._get_pose_embed()
)
return monitor_tensors
class OvisPreTrainedModel(PreTrainedModel):
config_class = OvisConfig
base_model_prefix = "ovis"
class Ovis(OvisPreTrainedModel):
_supports_flash_attn_2 = True
def __init__(self, config: OvisConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if kwargs.get('train_from_scratch'):
self.llm = kwargs['llm']
self.text_tokenizer = kwargs['text_tokenizer']
self.visual_tokenizer = kwargs['visual_tokenizer']
else:
self.llm = AutoModelForCausalLM.from_config(self.config.llm_config)
assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
self.visual_tokenizer = VisualTokenizer(vit=AutoModel.from_config(self.config.vit_config),
visual_vocab_size=self.config.visual_vocab_size,
image_processor_name_or_path=self.config.name_or_path)
# initialize vte
if is_deepspeed_zero3_enabled():
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
self.vte = VisualEmbedding(self.config.visual_vocab_size, self.config.hidden_size)
else:
self.vte = VisualEmbedding(self.config.visual_vocab_size, self.config.hidden_size,
device=self.visual_tokenizer.vit.device, dtype=self.visual_tokenizer.vit.dtype)
def _merge_modules(modules_list: tuple):
merged_modules = []
for modules in modules_list:
merged_modules.extend(modules if modules else [])
return merged_modules
self._no_split_modules = _merge_modules(
(self.llm._no_split_modules, self.visual_tokenizer.vit._no_split_modules))
self._skip_keys_device_placement = self.llm._skip_keys_device_placement
self._keep_in_fp32_modules = _merge_modules(
(self.llm._keep_in_fp32_modules, self.visual_tokenizer.vit._keep_in_fp32_modules))
self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.vit.is_parallelizable))
self.supports_gradient_checkpointing = True
def tie_weights(self):
self.llm.tie_weights()
def re_init_vte(self, mean, std):
vte = self.get_vte()
rank0_print(BEGIN_LINE)
rank0_print(f'[{datetime.now()}] Before re-initialization of vte: ')
with deepspeed.zero.GatheredParameters([vte.weight]):
rank0_print(f'vte.weight: {vte.weight}')
with deepspeed.zero.GatheredParameters([vte.weight], modifier_rank=0):
if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
vte.reset_parameters(mean, std)
rank0_print(f'[{datetime.now()}] After re-initialization of vte:')
with deepspeed.zero.GatheredParameters([vte.weight]):
rank0_print(f'vte.weight: {vte.weight}')
rank0_print(END_LINE)
def get_monitor_tensors(self):
monitor_tensors = dict(
wte=self.get_wte().weight,
lm_head=self.llm.get_output_embeddings().weight,
vte=self.vte.weight
)
monitor_tensors.update(
{f'visual_tokenizer_{k}': v for k, v in self.visual_tokenizer.get_monitor_tensors().items()})
return monitor_tensors
def get_wte(self):
return self.llm.get_input_embeddings()
def get_conversation_formatter(self) -> ConversationFormatter:
if getattr(self, 'conversation_formatter', None) is None:
self.conversation_formatter = getattr(import_module(".conversation_formatter", __package__),
self.config.conversation_formatter_class)(self.text_tokenizer)
return self.conversation_formatter
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
pixel_values: Optional[torch.Tensor],
grid_thws: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
**kwargs
):
inputs_embeds = self.merge_multimodal(
input_ids=input_ids,
pixel_values=pixel_values,
grid_thws=grid_thws,
)
return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
def merge_multimodal(
self,
input_ids: torch.Tensor,
pixel_values: Optional[torch.Tensor],
grid_thws: Optional[torch.Tensor],
):
placeholder_token_mask = torch.lt(input_ids, 0)
multimodal_embeds = self.get_wte()(torch.masked_fill(input_ids, placeholder_token_mask, 0))
# We need to create a dummy visual input in two cases:
# 1. During training in a distributed setup (e.g., DDP), to ensure that gradients
# for the visual encoder are synchronized, even if the real input is missing.
# This prevents the backward pass from hanging.
# 2. When using DeepSpeed ZeRO-3, which shards model parameters. A dummy input
# is required even during evaluation to ensure all model parameters are correctly
# gathered and the forward pass can complete.
need_dummy_visual_input = pixel_values is None and (self.training or is_deepspeed_zero3_enabled())
if need_dummy_visual_input:
pixel_values, grid_thws = self.visual_tokenizer.get_dummy_visual_inputs()
if pixel_values is not None:
visual_indicator_embeds = self.vte(torch.tensor(
list(range(self.config.visual_vocab_size - len(INDICATOR_IDS), self.config.visual_vocab_size)),
dtype=torch.long,
device=self.vte.weight.device
)).to(dtype=multimodal_embeds.dtype, device=multimodal_embeds.device)
visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
visual_embeds = self.vte(visual_tokens).to(dtype=multimodal_embeds.dtype, device=multimodal_embeds.device)
for i, indicator_id in enumerate(INDICATOR_IDS):
multimodal_embeds[input_ids == indicator_id] = visual_indicator_embeds[i]
multimodal_embeds[input_ids == VISUAL_ATOM_ID] = visual_embeds
if need_dummy_visual_input:
multimodal_embeds += visual_embeds.sum() * 0.0 + visual_indicator_embeds.sum() * 0.0
return multimodal_embeds
def _merge_inputs(
self, raw_input_ids, raw_labels, placeholder_indexes, grid_thws, indicator_begin_id, indicator_end_id
):
input_ids = []
labels = []
prev_index = 0
for placeholder_index, grid_thw in zip(placeholder_indexes, grid_thws):
input_ids.extend(raw_input_ids[prev_index:placeholder_index])
labels.extend(raw_labels[prev_index:placeholder_index])
num_image_atoms = grid_thw.prod().item()
num_image_atoms //= self.visual_tokenizer.vit.config.hidden_stride ** 2
num_image_atoms //= self.visual_tokenizer.vit.config.temporal_patch_size
input_ids.extend([indicator_begin_id] + [VISUAL_ATOM_ID] * num_image_atoms + [indicator_end_id])
labels.extend([IGNORE_ID] * (num_image_atoms + 2))
prev_index = placeholder_index + 1
input_ids.extend(raw_input_ids[prev_index:])
labels.extend(raw_labels[prev_index:])
return input_ids, labels
def preprocess_inputs(
self,
text_or_conversations: Union[List[Dict], str],
images: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
videos: Optional[Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
min_pixels=448 * 448,
max_pixels=1344 * 1792,
generation_preface='',
return_labels=False,
frame_selector=None,
# enable_thinking=False,
):
# convert text to conversations
if isinstance(text_or_conversations, str):
conversations = [{
"from": "human",
"value": text_or_conversations
}]
elif isinstance(text_or_conversations, list):
conversations = text_or_conversations
else:
raise ValueError(
f'[{datetime.now()}] Invalid type of `text_or_conversations`, expected `List[Dict]` or `str`,'
f' but got {type(text_or_conversations)}')
# select frame
if frame_selector is not None:
conversations, videos = frame_selector(conversations=conversations, frames=videos, clear_prompt=True)
# format conversations
prompt, raw_input_ids, raw_labels = self.get_conversation_formatter().format(
conversations, generation_preface=generation_preface)
image_token_indexes = [i for i, v in enumerate(raw_input_ids) if v == IMAGE_TOKEN_ID]
video_token_indexes = [i for i, v in enumerate(raw_input_ids) if v == VIDEO_TOKEN_ID]
# merge inputs
input_ids, labels = raw_input_ids, raw_labels
pixel_values, grid_thws = None, None
if images is not None and videos is not None:
raise ValueError(
"Multiple visual input data types detected (both `images` and `videos` provided). "
"This model supports only one type of visual input data at a time. "
"Please provide either `images` or `videos`, but not both."
)
if min(len(image_token_indexes), len(video_token_indexes)) > 0:
raise ValueError(
"Multiple visual modality placeholders detected in text (`<image>` and `<video>`). "
"The input text can contain placeholders for only one type of visual modality at a time. "
"Please use either `<image>` or `<video>` placeholders, but not both."
)
if images is None and videos is None and max(len(image_token_indexes), len(video_token_indexes)) > 0:
raise ValueError(
"Visual modality placeholder(s) detected in the input text "
"(e.g., `<image>` or `<video>`), but no corresponding visual data (`images` or `videos`) was supplied. "
"A visual placeholder requires the corresponding data to be processed. "
"To resolve this issue, please either: "
"1. Remove the visual placeholder(s) from your input text, OR "
"2. Provide the appropriate `images` or `videos` data alongside the text."
)
if images is not None:
images = images if isinstance(images, list) else [images]
pixel_values, grid_thws = zip(
*(self.visual_tokenizer.preprocess(image=image, min_pixels=min_pixels, max_pixels=max_pixels)
for image in images)
)
assert len(image_token_indexes) == len(pixel_values), f"Mismatch: {len(pixel_values)} image(s) provided but {len(image_token_indexes)} `<image>` placeholder(s) found in the text"
input_ids, labels = self._merge_inputs(
raw_input_ids, raw_labels, image_token_indexes, grid_thws, INDICATOR_IDS[0], INDICATOR_IDS[1]
)
pixel_values = torch.cat(pixel_values, dim=0)
grid_thws = torch.cat(grid_thws, dim=0)
elif videos is not None:
videos = videos if isinstance(videos[0], list) else [videos]
assert len(videos) == 1, "only support single video"
pixel_values, grid_thws = self.visual_tokenizer.preprocess(
video=videos[0], min_pixels=min_pixels, max_pixels=max_pixels
)
assert len(video_token_indexes) == len(videos), f"Mismatch: {len(videos)} video(s) provided but {len(video_token_indexes)} `<video>` placeholder(s) found in the text"
input_ids, labels = self._merge_inputs(
raw_input_ids, raw_labels, video_token_indexes, grid_thws, INDICATOR_IDS[2], INDICATOR_IDS[3]
)
input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
if return_labels:
assert all([label == IGNORE_ID or label >= 0 for label in labels]), "Invalid labels"
labels = torch.tensor(labels, dtype=torch.long).unsqueeze(0)
return prompt, input_ids, pixel_values, grid_thws, labels
else:
return prompt, input_ids, pixel_values, grid_thws
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
state_dict: Optional[dict] = None,
save_function: Callable = torch.save,
push_to_hub: bool = False,
max_shard_size: Union[int, str] = "5GB",
safe_serialization: bool = True,
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
save_peft_format: bool = True,
**kwargs
):
super().save_pretrained(save_directory,
is_main_process=is_main_process,
state_dict=state_dict,
save_function=save_function,
safe_serialization=safe_serialization)
self.text_tokenizer.save_pretrained(save_directory)
self.visual_tokenizer.image_processor.save_pretrained(save_directory)
def generate(
self,
inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
attention_mask = torch.ne(inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
inputs_embeds = self.merge_multimodal(
input_ids=inputs,
pixel_values=kwargs.pop('pixel_values', None),
grid_thws=kwargs.pop('grid_thws', None)
)
inputs_embeds = inputs_embeds.detach()
torch.cuda.empty_cache()
return self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
@torch.no_grad()
def chat(
self,
prompt: str,
images: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
videos: Optional[Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
do_sample: bool = False,
max_new_tokens: int = 512,
enable_thinking: bool = False,
thinking_budget: Optional[int] = None,
min_pixels: int = 448 * 448, # Parameter for image preprocessing
max_pixels: int = 1792 * 1792, # Parameter for image preprocessing
history: Optional[Dict] = None,
**generate_kwargs, # Allows passing other generation arguments
):
"""
Performs a single turn of conversation, optionally including visual input.
Supports a two-phase generation process with a "thinking_budget" for complex reasoning.
Args:
prompt (str): The user's input prompt.
images (Optional): Optional single image or list of images.
videos (Optional): Optional single video (list of frames) or list of videos.
do_sample (bool): Whether to use sampling during generation.
max_new_tokens (int): The maximum number of new tokens to generate in total.
enable_thinking (bool): If True, enables the model's Chain-of-Thought process.
thinking_budget (Optional[int]): The maximum number of tokens for the "thinking" phase.
If the model doesn't finish thinking within this budget,
it will be forced to start generating the final answer.
min_pixels (int): Minimum total pixels for image processing.
max_pixels (int): Maximum total pixels for image processing.
history (Optional[Dict]): Conversation history.
**generate_kwargs: Additional arguments for the generation method.
Returns:
Tuple[str, str, Dict]: A tuple containing:
- response (str): The final, user-facing response.
- thinking (str): The model's internal thought process (if enable_thinking=True).
- updated_history (Dict): The updated conversation history.
"""
# Initialize history if starting a new conversation
if history is None:
history = {"conversations": [], "images": None, "videos": None}
conversations = history["conversations"] + [{"from": "human", "value": prompt}]
current_images = (images if isinstance(images, list) else [images]) if images is not None else []
combined_images = (history["images"] or []) + current_images
combined_images = combined_images or None
current_videos = (videos if isinstance(videos[0], list) else [videos]) if videos is not None else []
combined_videos = (history["videos"] or []) + current_videos
combined_videos = combined_videos or None
_, initial_input_ids, pixel_values, grid_thws = self.preprocess_inputs(
conversations,
images=combined_images,
videos=combined_videos,
min_pixels=min_pixels,
max_pixels=max_pixels,
generation_preface="<think>\n\n</think>\n\n" if not enable_thinking else ''
)
initial_input_ids = initial_input_ids.to(device=self.device)
if pixel_values is not None:
pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
if grid_thws is not None:
grid_thws = grid_thws.to(device=self.device)
THINK_END_TOKEN_ID = 151668 # </think>
IM_END_TOKEN_ID = 151645 # <|im_end|>
common_generate_args = {
"pixel_values": pixel_values,
"grid_thws": grid_thws,
"do_sample": do_sample,
"pad_token_id": self.text_tokenizer.pad_token_id,
**generate_kwargs,
}
use_thinking_phase = enable_thinking and thinking_budget is not None and thinking_budget > 0
if not use_thinking_phase:
generated_ids = self.generate(
initial_input_ids,
max_new_tokens=max_new_tokens,
**common_generate_args
)
else:
# stage1: thinking_budget
phase1_output_ids = self.generate(
initial_input_ids,
max_new_tokens=thinking_budget,
**common_generate_args
)
if IM_END_TOKEN_ID in phase1_output_ids[0]:
generated_ids = phase1_output_ids
else:
intermediate_ids = phase1_output_ids
if THINK_END_TOKEN_ID not in intermediate_ids[0]:
early_stop_text = (
"\n\nConsidering the limited time by the user, I have to give the solution "
"based on the thinking directly now.\n</think>\n\n"
)
early_stop_ids = self.text_tokenizer(
early_stop_text, return_tensors="pt", add_special_tokens=False
).input_ids.to(self.device)
intermediate_ids = torch.cat([intermediate_ids, early_stop_ids], dim=1)
# stage2: complete the generation
phase1_tokens_consumed = intermediate_ids.shape[1]
remaining_tokens = max_new_tokens - phase1_tokens_consumed
if remaining_tokens > 0:
combined_input_ids = torch.cat([initial_input_ids, intermediate_ids], dim=1)
phase2_output_ids = self.generate(
combined_input_ids,
max_new_tokens=remaining_tokens,
**common_generate_args
)
generated_ids = torch.cat([intermediate_ids, phase2_output_ids], dim=1)
else:
generated_ids = intermediate_ids
full_generated_ids_list = generated_ids[0].tolist()
thinking, response = "", ""
if enable_thinking:
try:
think_end_idx = full_generated_ids_list.index(THINK_END_TOKEN_ID) + 1
thinking_ids = full_generated_ids_list[:think_end_idx]
response_ids = full_generated_ids_list[think_end_idx:]
thinking = self.text_tokenizer.decode(thinking_ids, skip_special_tokens=True).strip()
response = self.text_tokenizer.decode(response_ids, skip_special_tokens=True).strip()
except ValueError:
response = self.text_tokenizer.decode(full_generated_ids_list, skip_special_tokens=True).strip()
else:
response = self.text_tokenizer.decode(full_generated_ids_list, skip_special_tokens=True).strip()
updated_history = {
"conversations": conversations + [{"from": "gpt", "value": response}],
"images": combined_images,
"videos": combined_videos
}
return response, thinking, updated_history
AutoConfig.register("ovis", OvisConfig)
AutoModelForCausalLM.register(OvisConfig, Ovis)
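With `OvisConfig` and `Ovis` registered above, a hedged end-to-end sketch (hypothetical checkpoint path and image file; assumes a CUDA device, bfloat16 weights, and `flash_attn`):

import torch
from PIL import Image
from transformers import AutoModelForCausalLM
from ovis.model.modeling_ovis import Ovis  # noqa: F401  -- importing this module runs the register calls

model = AutoModelForCausalLM.from_pretrained(
    "path/to/Ovis2.5-checkpoint",   # hypothetical local path
    torch_dtype=torch.bfloat16,
).cuda()

response, thinking, history = model.chat(
    prompt="<image>\nWhat is shown in this picture?",
    images=[Image.open("example.jpg")],   # hypothetical image
    enable_thinking=True,
    thinking_budget=256,    # cap the <think> phase; the answer uses the rest of max_new_tokens
    max_new_tokens=1024,
)
print(thinking)
print(response)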
from typing import Any, Optional
from transformers.configuration_utils import PretrainedConfig
class Siglip2NavitConfig(PretrainedConfig):
"""This is the configuration class to store the configuration of an [`Siglip2Navit`].
Args:
hidden_size: Dimension of the hidden representations.
intermediate_size: Dimension of the SwiGLU representations.
num_hidden_layers: Number of hidden layers in the Transformer.
num_attention_heads: Number of attention heads for each attention layer
in the Transformer.
num_channels: Number of input channels.
image_size: Image size.
patch_size: Patch size.
rms_norm_eps: Epsilon value used for the RMS normalization layer.
attention_dropout: Dropout ratio for attention probabilities.
projection_dropout: Dropout ratio for the projection layer after the attention.
qkv_bias: Whether to add a bias to the queries, keys and values.
use_bias: Whether to add a bias in the feed-forward and projection layers.
kwargs: Keyword arguments for the [`PretrainedConfig`].
"""
model_type: str = "siglip2_navit"
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 4096,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_channels: int = 3,
num_patches: int = -1,
image_size: int = 512,
patch_size: int = 16,
hidden_act: str="gelu_pytorch_tanh",
layer_norm_eps: float = 1e-6,
attention_dropout: float = 0.0,
hidden_stride: int = 2,
window_size: int = 112,
fullatt_block_indexes: Optional[list] = None,
temporal_patch_size: int = 1,
preserve_original_pe: bool = True,
use_rope: bool = True,
**kwargs: Any,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.num_patches = num_patches
self.patch_size = patch_size
self.image_size = image_size
self.hidden_act = hidden_act
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_stride = hidden_stride
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.temporal_patch_size = temporal_patch_size
self.preserve_original_pe = preserve_original_pe
self.use_rope = use_rope
__all__ = ["Siglip2NavitConfig"]
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_siglip2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
from flash_attn import flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb
from transformers.activations import ACT2FN
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.modeling_utils import PreTrainedModel
from .configuration_siglip2_navit import Siglip2NavitConfig
__all__ = ["Siglip2NavitModel"]
# copied from qwen2.5-vl
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
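A quick shape sketch of the rotary table above (illustrative sizes; the encoder below instantiates it with dim = head_dim // 2, i.e. 32 for the default 1024-wide, 16-head config). Importing the module assumes `flash_attn` is installed, since it is imported at the top of this file.

from ovis.model.vit.modeling_siglip2_navit import VisionRotaryEmbedding

rotary = VisionRotaryEmbedding(dim=32)
freqs = rotary(seqlen=64)       # one row of frequencies per grid coordinate
print(freqs.shape)              # torch.Size([64, 16]) -- dim // 2 frequencies per position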
class Siglip2VisionEmbeddings(nn.Module):
def __init__(self, config: Siglip2NavitConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
self.image_size = config.image_size
self.num_patches = config.num_patches
self.preserve_original_pe = config.preserve_original_pe
self.hidden_stride = config.hidden_stride
# siglip2 naflex
if self.num_patches > 0:
self.patch_embedding = nn.Linear(
in_features=config.num_channels * self.patch_size * self.patch_size,
out_features=self.embed_dim,
)
if self.preserve_original_pe:
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
else:
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
if self.preserve_original_pe:
self.num_patches = (self.image_size // self.patch_size) ** 2
self.position_embedding_size = self.image_size // self.patch_size
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
@staticmethod
def resize_positional_embeddings(
positional_embeddings: torch.Tensor,
spatial_shapes: torch.LongTensor,
max_length: int,
) -> torch.Tensor:
"""
Resize positional embeddings to image-specific size and pad to a fixed size.
Args:
positional_embeddings (`torch.Tensor`):
Position embeddings of shape (height, width, embed_dim)
spatial_shapes (`torch.LongTensor`):
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
max_length (`int`):
Maximum length of the positional embeddings to pad resized positional embeddings to
Returns:
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
"""
batch_size = spatial_shapes.shape[0]
embed_dim = positional_embeddings.shape[-1]
source_dtype = positional_embeddings.dtype
resulted_positional_embeddings = torch.empty(
(batch_size, max_length, embed_dim),
device=positional_embeddings.device,
dtype=source_dtype,
)
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
# Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
if positional_embeddings.device.type == "cpu":
positional_embeddings = positional_embeddings.to(torch.float32)
for i in range(batch_size):
# (1, dim, height, width) -> (1, dim, target_height, target_width)
height, width = spatial_shapes[i]
resized_embeddings = F.interpolate(
positional_embeddings,
size=(height, width),
mode="bilinear",
align_corners=False,
antialias=True,
)
# (1, dim, target_height, target_width) -> (target_height * target_width, dim)
resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
# Cast to original dtype
resized_embeddings = resized_embeddings.to(source_dtype)
resulted_positional_embeddings[i, : height * width] = resized_embeddings
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
return resulted_positional_embeddings
def forward(self, pixel_values: torch.FloatTensor,
grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor`):
Pixel values of shape (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
grid_thws: (`torch.LongTensor`):
grid shape (num_patches, 3)
"""
# Apply patch embeddings to already patchified pixel values
target_dtype = self.patch_embedding.weight.dtype
if isinstance(self.patch_embedding, nn.Linear):
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
elif isinstance(self.patch_embedding, nn.Conv2d):
pixel_values = pixel_values.view(-1, self.config.num_channels * self.config.temporal_patch_size, self.patch_size,
self.patch_size)
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
if self.preserve_original_pe:
assert grid_thws is not None
pos_embed_new = torch.zeros_like(patch_embeds)
ori_h = ori_w = self.position_embedding_size
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
).unsqueeze(0).permute(0,3,1,2)
# pos_embed = self.pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
cnt = 0
for t, h, w in grid_thws:
thw = t * h * w
pe = F.interpolate(positional_embeddings, size=(h, w), mode='bicubic', align_corners=False)
pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
pe = pe[0].repeat(t, 1)
pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride,
self.hidden_stride, -1)
pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1)
pos_embed_new[cnt:cnt + thw] = pe
cnt += thw
patch_embeds = patch_embeds + pos_embed_new
return patch_embeds
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
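`eager_attention_forward` is defined here but not used by `Siglip2Attention` below, which goes through the varlen flash-attention kernel instead. A quick shape sketch with illustrative sizes (any `nn.Module` works as the first argument; it is only read for its `training` flag):

import torch
import torch.nn as nn
from ovis.model.vit.modeling_siglip2_navit import eager_attention_forward

q = torch.randn(1, 8, 10, 64)   # (batch, num_heads, seq_len, head_dim), illustrative
k = torch.randn(1, 8, 10, 64)
v = torch.randn(1, 8, 10, 64)
out, weights = eager_attention_forward(nn.Identity(), q, k, v, attention_mask=None, scaling=64 ** -0.5)
print(out.shape)      # torch.Size([1, 10, 8, 64]) -- heads moved next to head_dim
print(weights.shape)  # torch.Size([1, 8, 10, 10])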
# copied from qwen2.5-vl
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
class Siglip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.use_rope = config.use_rope
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(seq_length, self.num_heads, self.head_dim)
keys = keys.view(seq_length, self.num_heads, self.head_dim)
values = values.view(seq_length, self.num_heads, self.head_dim)
if self.use_rope:
cos, sin = position_embeddings
queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
queries = queries.squeeze(0)
keys = keys.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.out_proj(attn_output)
return attn_output
class Siglip2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Siglip2EncoderLayer(nn.Module):
def __init__(self, config: Siglip2NavitConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.self_attn = Siglip2Attention(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].
Args:
config: Siglip2NavitConfig
"""
def __init__(self, config: Siglip2NavitConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.rotary_pos_emb = VisionRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
self.patch_size = config.patch_size
self.hidden_stride = config.hidden_stride
self.window_size = config.window_size
self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
self.fullatt_block_indexes = None if config.fullatt_block_indexes is None else [int(i) for i in config.fullatt_block_indexes.split('|')]
# copied from qwen2.5_vl
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.hidden_stride,
self.hidden_stride,
w // self.hidden_stride,
self.hidden_stride,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.hidden_stride,
self.hidden_stride,
w // self.hidden_stride,
self.hidden_stride,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size # patch (after merge) number in each window
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.hidden_stride, # number of patch after merge
grid_w // self.hidden_stride,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
def forward(
self,
inputs_embeds,
grid_thws: torch.Tensor,
output_hidden_states: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
rotary_pos_emb = self.rot_pos_emb(grid_thws)
window_index, cu_window_seqlens = self.get_window_index(grid_thws)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=inputs_embeds.device,
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
seq_len, _ = inputs_embeds.size()
inputs_embeds = inputs_embeds.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
inputs_embeds = inputs_embeds[window_index, :, :]
inputs_embeds = inputs_embeds.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
dim=0,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852 for more information
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
reverse_indices = torch.argsort(window_index)
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
for index, block in enumerate(self.layers):
if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes:
cu_seqlens_tmp = cu_seqlens
else:
cu_seqlens_tmp = cu_window_seqlens
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states, cu_seqlens_tmp, position_embeddings)
else:
hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)
if output_hidden_states:
hidden_states_ = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
encoder_states += (hidden_states_[reverse_indices, :].reshape(seq_len, -1),)
# tokens = self.post_trunk_norm(tokens)
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
return hidden_states, encoder_states
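A small sketch of the window bookkeeping above (illustrative grid; assumes `flash_attn` is importable since this module requires it): a single 256x256 image at patch_size 16 gives a 16x16 patch grid, which merges into an 8x8 grid, so `window_index` has 64 entries and the window boundaries end at 256 packed patch tokens.

import torch
from ovis.model.vit.configuration_siglip2_navit import Siglip2NavitConfig
from ovis.model.vit.modeling_siglip2_navit import Siglip2Encoder

encoder = Siglip2Encoder(Siglip2NavitConfig(num_hidden_layers=1))   # tiny, illustrative
window_index, cu_window_seqlens = encoder.get_window_index(torch.tensor([[1, 16, 16]]))
print(window_index.numel())   # 64  == 1 * (16 // 2) * (16 // 2) merged positions
print(cu_window_seqlens[-1])  # 256 == total number of packed patch tokens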
class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: Siglip2NavitConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: torch.LongTensor,
output_hidden_states: Optional[bool] = True,
return_dict: Optional[bool] = True,
) -> Union[
Tuple[torch.Tensor],
Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
BaseModelOutputWithNoAttention,
]:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_hidden_states = (
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
# )
hidden_states = self.embeddings(pixel_values, grid_thws)
last_hidden_state, hidden_states = self.encoder(hidden_states, grid_thws, output_hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
if not return_dict:
output = (last_hidden_state,)
output += (hidden_states,) if output_hidden_states else ()
return output
return BaseModelOutputWithNoAttention(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states
)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
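A tiny sketch of these initializers on an illustrative tensor: `lecun_normal_` fills the weight from a truncated normal whose variance is 1 / fan_in, so for a (256, 128) matrix the standard deviation lands near (1 / 128) ** 0.5, about 0.088, and samples stay within roughly +-0.2 because of the +-2 sigma truncation.

import torch
from ovis.model.vit.modeling_siglip2_navit import lecun_normal_

w = torch.empty(256, 128)
lecun_normal_(w)                # fan_in = 128
print(w.std().item())           # close to (1 / 128) ** 0.5 ~= 0.088
print(w.abs().max().item())     # about 0.2 at most, due to the truncation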
class Siglip2PreTrainedModel(PreTrainedModel):
config_class = Siglip2NavitConfig
base_model_prefix = "siglip2_navit"
supports_gradient_checkpointing = True
_no_split_modules = [
"Siglip2VisionEmbeddings",
"Siglip2EncoderLayer",
]
_supports_flash_attn_2 = True
_supports_sdpa = False
_supports_flex_attn = False
_supports_attention_backend = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Siglip2VisionEmbeddings):
width = self.config.hidden_size
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, Siglip2Attention):
nn.init.xavier_uniform_(module.q_proj.weight)
nn.init.xavier_uniform_(module.k_proj.weight)
nn.init.xavier_uniform_(module.v_proj.weight)
nn.init.xavier_uniform_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, Siglip2MLP):
nn.init.xavier_uniform_(module.fc1.weight)
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
# (the pooling-head, contrastive and classification modules of upstream Siglip2 are not defined
# in this navit-only file, so no initialization branches are needed for them here)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class Siglip2NavitModel(Siglip2PreTrainedModel):
config_class = Siglip2NavitConfig
main_input_name = "pixel_values"
def __init__(self, config: Siglip2NavitConfig):
super().__init__(config)
self.vision_model = Siglip2VisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: torch.LongTensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[
Tuple[torch.Tensor],
Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
BaseModelOutputWithNoAttention,
]:
r"""
pixel_values (`torch.FloatTensor`):
Packed patch pixel values for all images (or video frames) in the batch.
grid_thws (`torch.LongTensor`):
Per-image patch grid sizes `(t, h, w)` describing how the packed `pixel_values` are laid out.
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Siglip2VisionModel
>>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
if return_dict is None:
return_dict = self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
grid_thws=grid_thws,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def _get_block(self, layer_index):
return self.vision_model.encoder.layers[layer_index]
def _get_attn_weight(self, layer_index):
return torch.cat([self._get_block(layer_index).self_attn.q_proj.weight,
self._get_block(layer_index).self_attn.k_proj.weight,
self._get_block(layer_index).self_attn.v_proj.weight])
def _get_pose_embed(self):
return self.vision_model.embeddings.position_embedding.weight
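# Hedged calling sketch (the shapes below are assumptions inferred from the parameter
# names and typical NaViT-style patch packing; they are not guaranteed by this file).
# Patches from all images are flattened into a single sequence, and `grid_thws`
# records each image's (temporal, height, width) patch grid so the encoder can
# unpack them:
#
#   vit = Siglip2NavitModel(Siglip2NavitConfig())
#   pixel_values = torch.randn(1024, patch_dim)                 # patch_dim depends on the config
#   grid_thws = torch.tensor([[1, 32, 32]], dtype=torch.long)   # one 32x32-patch image
#   out = vit(pixel_values=pixel_values, grid_thws=grid_thws)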
import torch
from PIL import Image
from ovis.model.modeling_ovis import Ovis
# If you need video support, make sure moviepy is installed first:
# pip install moviepy==1.0.3
try:
from moviepy.editor import VideoFileClip # type: ignore
_HAS_MOVIEPY = True
except Exception:
_HAS_MOVIEPY = False
def run_single_image_example(model: Ovis, image_path: str) -> None:
"""
Run an inference example with a single image input.
"""
print("--- 1) Single-image example ---")
images = [Image.open(image_path).convert("RGB")]
prompt = "<image>\nDescribe this image in detail."
print(f"Prompt:\n{prompt}")
response, _, _ = model.chat(
prompt=prompt,
images=images,
min_pixels=448 * 448,
max_pixels=1792 * 1792,
videos=None,
do_sample=True,
max_new_tokens=1024,
)
print(f"\nResponse:\n{response}")
def run_multi_image_example(model: Ovis, image_paths: list) -> None:
"""
Run an inference example with multiple image inputs.
"""
print("--- 2) Multi-image example ---")
images = [Image.open(p).convert("RGB") for p in image_paths]
prompt = "<image>\n<image>\n<image>\nWhat is the relationship between the third image and the first two?"
print(f"Prompt:\n{prompt}")
response, _, _ = model.chat(
prompt=prompt,
images=images,
min_pixels=448 * 448,
max_pixels=896 * 896,
videos=None,
do_sample=True,
max_new_tokens=1024,
)
print(f"\nResponse:\n{response}")
def run_video_example(model: Ovis, video_path: str, num_frames: int = 8) -> None:
"""
Run an inference example with a video input.
"""
if not _HAS_MOVIEPY:
raise ImportError(
"moviepy is not installed. Install it with `pip install moviepy==1.0.3` to use video examples."
)
print("--- 3) Video example ---")
with VideoFileClip(video_path) as clip:
total_frames = int(clip.fps * clip.duration)
indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
frames = [
Image.fromarray(clip.get_frame(t)) for t in (index / clip.fps for index in indices)
]
videos = [frames]
prompt = "<video>\nDescribe this video in detail."
print(f"Prompt:\n{prompt}")
response, _, _ = model.chat(
prompt=prompt,
images=None,
videos=videos,
min_pixels=448 * 448,
max_pixels=896 * 896,
do_sample=True,
max_new_tokens=1024,
)
print(f"\nResponse:\n{response}")
def run_text_only_example(model: Ovis) -> None:
"""
Run an inference example with text-only input.
"""
print("--- 4) Text-only example ---")
prompt = "Hi, please introduce Huangshan (Yellow Mountain) in Chinese."
print(f"Prompt:\n{prompt}")
response, _, _ = model.chat(
prompt=prompt,
images=None,
videos=None,
do_sample=True,
max_new_tokens=1024,
)
print(f"\nResponse:\n{response}")
if __name__ == "__main__":
# --- 1) Load model ---
model_path = "AIDC-AI/Ovis2.5-9B"
print("Loading model, please wait...")
model = (
Ovis.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="cuda:0",
).eval()
)
print("Model loaded.")
print("\n========================================\n")
# --- 2) Define file paths (anonymized placeholders) ---
# Replace the following with your own paths
single_image_file = "/path/to/image1.jpg"
multi_image_files = [
"/path/to/image1.jpg",
"/path/to/image2.jpg",
"/path/to/image3.png",
]
video_file = "/path/to/video1.mp4"
# --- 3) Run examples ---
run_single_image_example(model, single_image_file)
print("\n========================================\n")
run_multi_image_example(model, multi_image_files)
print("\n========================================\n")
run_video_example(model, video_file)
print("\n========================================\n")
run_text_only_example(model)
print("\n========================================\n")
import torch
from PIL import Image
from ovis.model.modeling_ovis import Ovis
MODEL_PATH = "AIDC-AI/Ovis2.5-9B"
# Enable reflective reasoning mode (thinking mode)
enable_thinking = True
# Total tokens = thinking phase + response
max_new_tokens = 3072
# thinking_budget: upper bound of tokens reserved for the "thinking phase"
# - If provided, the model will stop thinking once this budget is reached,
# then switch to generating the final response.
# - If omitted when calling .chat(), it is equivalent to "not set",
# and the model may use all max_new_tokens for thinking.
thinking_budget = 2048
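# (Sketch) For an uncapped thinking phase, simply omit `thinking_budget` from the
# .chat() call below, e.g.:
#   response, thinking, _ = model.chat(prompt=prompt, images=images,
#                                      enable_thinking=True, max_new_tokens=max_new_tokens)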
# Load model
model = Ovis.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
device_map="cuda"
).eval()
prompt = "<image>\nDescribe this image in detail."
images = [Image.open("/path/to/image1.jpg")]
# Run chat
response, thinking, _ = model.chat(
prompt=prompt,
images=images,
history=None,
do_sample=True,
max_new_tokens=max_new_tokens,
enable_thinking=enable_thinking,
thinking_budget=thinking_budget, # omit this arg => unlimited thinking
)
# Print results
if enable_thinking and thinking:
print("=== Thinking ===")
print(thinking)
print("\n=== Response ===")
print(response)
else:
print("Response:", response)
import argparse
from typing import List, Optional, Tuple
import PIL.Image
import gradio as gr
import moviepy.editor as mp
import numpy as np
import torch
from ovis.model.modeling_ovis import Ovis
model: Optional[Ovis] = None
def load_video_frames(video_path: Optional[str], n_frames: int = 8) -> Optional[List[PIL.Image.Image]]:
"""Extract a fixed number of frames from the video file."""
if not video_path:
return None
try:
with mp.VideoFileClip(video_path) as clip:
duration = clip.duration
if duration is None or clip.fps is None or duration <= 0 or clip.fps <= 0:
print(f"Warning: Unable to process video {video_path}. Invalid duration or fps.")
return None
total_possible_frames = int(duration * clip.fps)
num_to_extract = min(n_frames, total_possible_frames)
if num_to_extract <= 0:
print(f"Warning: Cannot extract frames from {video_path}. Computed extractable frames is zero.")
return None
frames = []
timestamps = np.linspace(0, duration, num_to_extract, endpoint=True)
for t in timestamps:
frame_np = clip.get_frame(t)
frames.append(PIL.Image.fromarray(frame_np))
print(f"Successfully extracted {len(frames)} frames from {video_path}.")
return frames
except Exception as e:
print(f"Error processing video {video_path}: {e}")
return None
def run_single_model(
image_input: Optional[PIL.Image.Image],
video_input: Optional[str],
prompt: str,
do_sample: bool,
max_new_tokens: int,
enable_thinking: bool
) -> str:
"""Run single model inference."""
if not prompt and not image_input and not video_input:
gr.Warning("Please enter a prompt, upload an image, or upload a video.")
return ""
# Prepare vision inputs
images = [image_input] if image_input else None
video_frames = load_video_frames(video_input)
videos = [video_frames] if video_frames else None
# Construct full prompt with placeholders
visual_placeholders = ('<image>\n' * len(images) if images else "") + ('<video>\n' if videos else "")
full_prompt = visual_placeholders + prompt
# Call model chat method
response, thinking, _ = model.chat(
prompt=full_prompt,
history=None, # Always start a new conversation
images=images,
videos=videos,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
enable_thinking=enable_thinking,
)
# Format output
if enable_thinking and thinking:
return f"**Thinking:**\n```text\n{thinking}\n```\n\n**Response:**\n{response}"
return response
def toggle_media_input(choice: str) -> Tuple[gr.update, gr.update]:
"""Toggle visibility of image and video input components."""
if choice == "Image":
return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
else:
return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
def clear_interface() -> Tuple[str, None, None, str, str]:
"""Reset all inputs and outputs."""
return "", None, None, "", "Image"
def start_generation() -> Tuple[gr.update, gr.update, gr.update]:
"""Update UI status when generation starts."""
return (
gr.update(value="⏳ Generating...", interactive=False),
gr.update(interactive=False),
gr.update(value="⏳ Model is generating...")
)
def finish_generation() -> Tuple[gr.update, gr.update]:
"""Restore UI status after generation ends."""
return gr.update(value="Generate", interactive=True), gr.update(interactive=True)
def build_demo(model_path: str, gpu: int):
"""Build single-model Gradio demo interface."""
global model
device = f"cuda:{gpu}"
print(f"Loading model {model_path} to device {device}...")
model = Ovis.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=device).eval()
print("Model loaded successfully.")
custom_css = "#output_md .prose { font-size: 18px !important; }"
with gr.Blocks(theme=gr.themes.Default(), css=custom_css) as demo:
gr.Markdown("# Multimodal Large Language Model Interface")
gr.Markdown(f"Running on **GPU {gpu}**. Each submission starts a new conversation.")
with gr.Row():
# Left column - inputs
with gr.Column(scale=1):
gr.Markdown("### Inputs")
input_type_radio = gr.Radio(
choices=["Image", "Video"], value="Image", label="Select Input Type"
)
image_input = gr.Image(label="Image Input", type="pil", visible=True, height=400)
video_input = gr.Video(label="Video Input", visible=False)
prompt_input = gr.Textbox(
label="Prompt", placeholder="Enter your prompt here... (Press Enter to submit)", lines=3
)
with gr.Accordion("Generation Settings", open=True):
do_sample = gr.Checkbox(label="Enable Sampling (Do Sample)", value=False)
max_new_tokens = gr.Slider(
minimum=32, maximum=2048, value=1024, step=32, label="Max New Tokens"
)
enable_thinking = gr.Checkbox(label="Deep Thinking", value=False)
with gr.Row():
clear_btn = gr.Button("Clear", variant="secondary", scale=1)
generate_btn = gr.Button("Generate", variant="primary", scale=2)
# Right column - output
with gr.Column(scale=2):
model_name = model_path.split('/')[-1]
gr.Markdown(f"### Model Output\n`{model_name}`")
output_display = gr.Markdown(elem_id="output_md")
# Event handlers
input_type_radio.change(
fn=toggle_media_input,
inputs=input_type_radio,
outputs=[image_input, video_input]
)
run_inputs = [image_input, video_input, prompt_input, do_sample, max_new_tokens, enable_thinking]
generate_btn.click(
fn=start_generation,
outputs=[generate_btn, clear_btn, output_display]
).then(
fn=run_single_model,
inputs=run_inputs,
outputs=[output_display]
).then(
fn=finish_generation,
outputs=[generate_btn, clear_btn]
)
prompt_input.submit(
fn=start_generation,
outputs=[generate_btn, clear_btn, output_display]
).then(
fn=run_single_model,
inputs=run_inputs,
outputs=[output_display]
).then(
fn=finish_generation,
outputs=[generate_btn, clear_btn]
)
clear_btn.click(
fn=clear_interface,
outputs=[output_display, image_input, video_input, prompt_input, input_type_radio]
).then(
fn=toggle_media_input,
inputs=input_type_radio,
outputs=[image_input, video_input]
)
return demo
def parse_args():
parser = argparse.ArgumentParser(description="Gradio interface for Ovis.")
parser.add_argument("--model-path", type=str)
parser.add_argument("--gpu", type=int, default=0, help="GPU index to run the model.")
parser.add_argument("--port", type=int, default=9901, help="Port to run the Gradio service.")
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Server name for Gradio app.")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
demo = build_demo(
model_path=args.model_path,
gpu=args.gpu
)
print(f"Launching Gradio app at http://{args.server_name}:{args.port}")
demo.queue().launch(
server_name=args.server_name,
server_port=args.port,
share=False,
ssl_verify=False
)
from dataclasses import dataclass, field
from typing import Optional
import transformers
from ovis.util.utils import rankN_print
@dataclass
class ModelArguments:
llm_name_or_path: Optional[str] = field(default=None)
vit_name_or_path: Optional[str] = field(default=None)
visual_vocab_size: int = field(default=65536)
conversation_formatter_class: str = field(default=None)
attn_implementation: Optional[str] = field(default=None)
accepts_loss_kwargs: bool = field(default=True)
vit_hidden_stride: int = field(default=2)
vit_window_size: int = field(default=112)
vit_temporal_patch_size: int = field(default=1)
vit_fullatt_block_indexes: Optional[str] = field(default=None)
vit_preserve_original_pe: Optional[bool] = field(default=True)
vit_use_rope: Optional[bool] = field(default=True)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
data_info_version: Optional[str] = field(default=None)
data_name: Optional[str] = field(default=None) # a|b|c
data_type: Optional[str] = field(default=None) # caption, conversation
ovis_pretrained_path: Optional[str] = field(default=None)
stage: Optional[int] = field(default=None)
train_modules: Optional[str] = field(default=None)
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
save_safetensors: bool = field(default=True)
monitor_step: int = field(default=100)
model_init_seed: int = field(default=0)
multimodal_max_length: int = field(default=4096)
text_max_length: Optional[int] = field(default=4096)
min_frames: int = field(default=8)
max_frames: int = field(default=8)
overall_ratio: Optional[str] = field(default=None)
mix_data_name: Optional[str] = field(default=None)
mix_ratio: Optional[float] = field(default=None)
min_lr_rate: Optional[float] = field(default=None)
single_image_min_pixels: int = field(default=448*448)
single_image_max_pixels: int = field(default=1792*1344)
multiple_image_min_pixels: int = field(default=448*448)
multiple_image_max_pixels: int = field(default=448*448)
video_min_pixels: int = field(default=448*448)
video_max_pixels: int = field(default=448*448)
def __post_init__(self):
if self.min_lr_rate is not None:
self.lr_scheduler_kwargs = {
"min_lr_rate": self.min_lr_rate
}
if self.gradient_checkpointing:
self.gradient_checkpointing_kwargs = {"use_reentrant": False}
if self.stage is not None and self.stage < 3:
self.save_safetensors = False
super().__post_init__()
assert self.model_init_seed != self.seed, "`model_init_seed` should be different from `seed`"
import gc
import time
import deepspeed
import torch
import torch.distributed as dist
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from ovis.util.constants import END_LINE, BEGIN_LINE
from ovis.util.utils import rankN_print
class TuneTauCallback(TrainerCallback):
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
visual_tokenizer = kwargs['model'].get_visual_tokenizer()
current_step = state.global_step
max_step = state.max_steps
ratio = current_step / max_step
visual_tokenizer.config.tau = args.visual_max_tau - (args.visual_max_tau - args.visual_min_tau) * ratio
class MonitorCallback(TrainerCallback):
def _monitoring(self, model, step):
with torch.no_grad():
with deepspeed.zero.GatheredParameters(model.get_monitor_tensors().values()):
for k, v in model.get_monitor_tensors().items():
rankN_print(BEGIN_LINE)
rankN_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
rankN_print(v)
rankN_print(END_LINE)
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = kwargs['model']
step = state.global_step
if step % args.monitor_step == 0 or step == 10: # monitor at step 10 for fast check
self._monitoring(model, step)
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = kwargs['model']
step = state.global_step
self._monitoring(model, step)