Commit 441126fe authored by luopl

"Initial commit"
import logging
from datetime import datetime
from typing import Dict
import pandas
import torch
from ovis.train.dataset.multimodal_dataset import MultimodalDataset
from ovis.util.constants import IMAGE_TOKEN, IGNORE_ID, VISUAL_ATOM_ID, INDICATOR_IDS
from ovis.util.utils import rank0_print
class CaptionDataset(MultimodalDataset):
def load(self):
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
samples = pandas.read_parquet(self.meta_file, engine='pyarrow')
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
return samples
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
sample = self.samples[i]
image_path = sample['image']
if isinstance(image_path, list):
assert len(image_path) == 1
image_path = image_path[0]
text = sample['caption'].replace(IMAGE_TOKEN, '').strip()
caption_template = sample['caption_template']
# process text
head, tail = caption_template.split(IMAGE_TOKEN)
head_ids = self.text_tokenizer(head, add_special_tokens=False).input_ids
tail_ids = self.text_tokenizer(tail, add_special_tokens=False).input_ids
text_ids = self.text_tokenizer(text, add_special_tokens=False).input_ids
# process image
try:
            image, last_e = self.read_image(image_path)
            assert image is not None, f"Failed to read image from {image_path}: {last_e}"
pixel_values, grid_thws = self.visual_tokenizer.preprocess(
image=image,
min_pixels=self.training_args.single_image_min_pixels,
max_pixels=self.training_args.single_image_max_pixels
)
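            # grid_thws[0] is the (t, h, w) ViT patch grid of this image; one visual atom
            # (one placeholder token fed to the LLM) covers hidden_stride**2 spatial patches
            # across temporal_patch_size frames, hence the two divisions below.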
num_image_atoms = grid_thws[0].prod().item()
num_image_atoms //= self.visual_tokenizer.vit.config.hidden_stride ** 2
num_image_atoms //= self.visual_tokenizer.vit.config.temporal_patch_size
image_placeholders = [INDICATOR_IDS[0]] + [VISUAL_ATOM_ID] * num_image_atoms + [INDICATOR_IDS[1]]
input_ids = head_ids + image_placeholders + tail_ids + text_ids
labels = [IGNORE_ID] * (len(input_ids) - len(text_ids)) + text_ids
            assert self.text_tokenizer.pad_token_id not in input_ids, \
                f"The sample's text contains a padding token: `{self.text_tokenizer.pad_token}`"
except Exception as e:
            logging.exception(f'processing sample failed with i: {i}, image_path: {image_path}')
pixel_values, grid_thws = None, None
input_ids = [0]
labels = [IGNORE_ID]
input_ids = input_ids[:self.training_args.multimodal_max_length]
labels = labels[:self.training_args.multimodal_max_length]
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_mask = torch.full_like(input_ids, fill_value=True, dtype=torch.bool)
labels = torch.tensor(labels, dtype=torch.long)
return dict(
input_ids=input_ids,
pixel_values=pixel_values,
grid_thws=grid_thws,
attention_mask=attention_mask,
labels=labels
)
import copy
import json
import logging
from datetime import datetime
from typing import Dict
import torch
from ovis.train.dataset.multimodal_dataset import MultimodalDataset
from ovis.util.constants import VIDEO_TOKEN, IMAGE_TOKEN, IGNORE_ID
from ovis.util.utils import rank0_print
class ConversationDataset(MultimodalDataset):
def load(self):
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
with open(self.meta_file, 'r', encoding='utf-8') as f:
samples = json.load(f)
rank0_print(f'#samples: {len(samples)}')
rank0_print(f'sample: {samples[0]}')
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
return samples
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
sample = self.samples[i]
conversations = sample["conversations"]
images = None
videos = None
n_image_or_frame = 0
if 'image' in sample:
images = []
image_paths = sample['image']
if isinstance(image_paths, str):
image_paths = [image_paths]
for image_path in image_paths:
image, last_e = self.read_image(image_path)
assert image is not None, f"Failed to read image from {image_path}"
images.append(image)
n_image_or_frame = len(images)
elif 'video' in sample or 'video_frames' in sample:
video, last_e = self.read_video(sample, min_frames=self.min_frames, max_frames=self.max_frames)
video_path = sample.get('video') or sample.get('video_frames')
assert video is not None, f"Failed to read video from {video_path}"
videos = [video]
n_image_or_frame = len(video)
if images is None and videos is None:
min_pixels = 0
max_pixels = 0
elif videos is not None:
min_pixels = self.training_args.video_min_pixels
max_pixels = self.training_args.video_max_pixels
elif len(images) == 1:
min_pixels = self.training_args.single_image_min_pixels
max_pixels = self.training_args.single_image_max_pixels
else:
min_pixels = self.training_args.multiple_image_min_pixels
max_pixels = self.training_args.multiple_image_max_pixels
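        # A negative budget means "derive it from the single-image budget": min_pixels falls back to
        # the single-image minimum, and max_pixels is the single-image maximum split evenly across
        # the images/frames (never below min_pixels).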
if min_pixels < 0:
min_pixels = self.training_args.single_image_min_pixels
if max_pixels < 0:
max_pixels = max(min_pixels, self.training_args.single_image_max_pixels // n_image_or_frame)
prompt, input_ids, pixel_values, grid_thws, labels = self.model.preprocess_inputs(
conversations,
images=images,
videos=videos,
min_pixels=min_pixels,
max_pixels=max_pixels,
generation_preface=None,
return_labels=True,
)
if pixel_values is None:
input_ids, pixel_values, grid_thws, labels = self.truncate_inputs(
input_ids, pixel_values, grid_thws, labels, max_length=self.training_args.text_max_length
)
else:
input_ids, pixel_values, grid_thws, labels = self.truncate_inputs(
input_ids, pixel_values, grid_thws, labels, max_length=self.training_args.multimodal_max_length
)
        assert self.text_tokenizer.pad_token_id not in input_ids, \
            f"The sample's text contains a padding token: `{self.text_tokenizer.pad_token}`"
del sample
return dict(
input_ids=input_ids,
pixel_values=pixel_values,
grid_thws=grid_thws,
attention_mask=torch.full_like(input_ids, fill_value=True, dtype=torch.bool),
labels=labels
)
import json
import logging
import os
import traceback
from typing import Dict, Sequence, Union, List
import numpy as np
import torch
import moviepy.editor as mp
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID, BEGIN_LINE, END_LINE, VISUAL_ATOM_ID, INDICATOR_IDS
class MultimodalDataset(Dataset):
def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
self.name = name
self.model = model
self.training_args = training_args
self.meta_file = info['meta_file']
self.image_dir = info['image_dir']
self.caption_template = info.get('caption_template', None)
self.text_tokenizer = self.model.text_tokenizer
self.visual_tokenizer = self.model.visual_tokenizer
self.text_max_length = training_args.text_max_length
self.min_frames = training_args.min_frames
self.max_frames = training_args.max_frames
self.samples = self.load()
def load(self):
raise NotImplementedError
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def __len__(self):
return len(self.samples)
def read_image(self, path):
try:
full_path = os.path.join(self.image_dir, path)
image = Image.open(full_path).convert('RGB')
return image, None
except Exception as e:
return None, e
def read_video(self, sample, min_frames, max_frames):
def _sampling_idx(_len, _min, _max):
if _len < _min or _len > _max:
tgt_len = _min if _len < _min else _max
stride = _len / tgt_len
sampled_ids = []
for i in range(tgt_len):
start = int(np.round(stride * i))
end = int(np.round(stride * (i + 1)))
sampled_ids.append(min(_len - 1, (start + end) // 2))
return sampled_ids
else:
return list(range(_len))
if "video_frames" in sample:
frames = []
frames_paths = sample['video_frames']
sampled_ids = _sampling_idx(len(frames_paths), min_frames, max_frames)
for idx in sampled_ids:
frame, last_e = self.read_image(os.path.join(self.image_dir, frames_paths[idx]))
if frame is None:
return None, last_e
frames.append(frame)
return frames, None
elif "video" in sample:
video_path = os.path.join(self.image_dir, sample['video'])
max_tries = 2
last_e = None
for _ in range(max_tries):
try:
with mp.VideoFileClip(video_path) as clip:
total_frames = int(clip.fps * clip.duration)
sampled_ids = _sampling_idx(total_frames, min_frames, max_frames)
frames = [clip.get_frame(idx / clip.fps) for idx in sampled_ids]
frames = [Image.fromarray(frame, mode='RGB') for frame in frames]
if len(frames) == 0 or any(frame.size[0] < 5 or frame.size[1] < 5 for frame in frames):
raise ValueError("frames are empty or there exists very small frame")
return frames, None
except Exception as e:
last_e = f"read video error: {e}\n detailed info: {traceback.format_exc()}"
return None, last_e
else:
return None, RuntimeError(f"missing `video_frames` and `video` in sample: {json.dumps(sample)}")
def truncate_inputs(
self, input_ids, pixel_values, grid_thws, labels, max_length
):
input_ids = input_ids[0, :max_length]
labels = labels[0, :max_length]
if input_ids[-1] in (VISUAL_ATOM_ID, INDICATOR_IDS[0], INDICATOR_IDS[2]): # incomplete visual input
last_text_id_pos = (input_ids >= 0).nonzero()[-1].item() + 1
input_ids = input_ids[:last_text_id_pos]
labels = labels[:last_text_id_pos]
num_visual_atom = torch.eq(input_ids, VISUAL_ATOM_ID).sum().item()
if num_visual_atom > 0:
vit = self.model.visual_tokenizer.vit
ratio = vit.config.temporal_patch_size * vit.config.hidden_stride ** 2
cumsum_patches = grid_thws.prod(dim=1).cumsum(dim=0)
last_grid_index = (cumsum_patches // ratio == num_visual_atom).nonzero().item()
pixel_values = pixel_values[:cumsum_patches[last_grid_index]]
grid_thws = grid_thws[:last_grid_index+1]
else:
pixel_values, grid_thws = None, None
return input_ids, pixel_values, grid_thws, labels
class DataCollatorForMultimodalDataset:
def __init__(self, text_tokenizer: PreTrainedTokenizer):
self.text_tokenizer = text_tokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
keys = ("input_ids", "pixel_values", "grid_thws", "attention_mask", "labels")
input_ids, pixel_values, grid_thws, attention_mask, labels = (
tuple(instance[key] for instance in instances) for key in keys
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.text_tokenizer.pad_token_id
)
pixel_values = [x for x in pixel_values if x is not None]
pixel_values = torch.cat(pixel_values, dim=0) if len(pixel_values) > 0 else None
grid_thws = [x for x in grid_thws if x is not None]
grid_thws = torch.cat(grid_thws, dim=0) if len(grid_thws) > 0 else None
attention_mask = torch.nn.utils.rnn.pad_sequence(
attention_mask,
batch_first=True,
padding_value=False
)
labels = torch.nn.utils.rnn.pad_sequence(
labels,
batch_first=True,
padding_value=IGNORE_ID
)
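        # If no sequence in the batch needed padding (every attention_mask entry is True), append one
        # extra padded position so that each batch still contains at least one pad token and one
        # ignored label position.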
if 0 not in attention_mask:
input_ids = F.pad(input_ids, (0, 1), value=self.text_tokenizer.pad_token_id)
attention_mask = F.pad(attention_mask, (0, 1), value=False)
labels = F.pad(labels, (0, 1), value=IGNORE_ID)
if torch.all(labels == IGNORE_ID):
logging.warning(f'[DataCollatorForMultimodalDataset] All samples in the current batch are ignored.')
return dict(
input_ids=input_ids,
pixel_values=pixel_values,
grid_thws=grid_thws,
attention_mask=attention_mask,
labels=labels
)
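# Usage sketch (hypothetical; dataset and collator construction mirror ovis/train/train.py):
#
#   from torch.utils.data import DataLoader
#   collator = DataCollatorForMultimodalDataset(model.text_tokenizer)
#   loader = DataLoader(train_dataset, batch_size=2, collate_fn=collator)
#   batch = next(iter(loader))
#   # input_ids / labels / attention_mask are padded to the longest sample in the batch;
#   # pixel_values and grid_thws are concatenated across samples (None for text-only batches).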
{
"geometry3k_local": {
"meta_file": "path/to/geometry3k_local.json",
"storage_type": "hybrid",
"data_format": "conversation",
"image_dir": "path/to/images/"
}
}
import json
import os
import pathlib
import deepspeed
import flash_attn
import torch
import transformers
from deepspeed import get_accelerator
from torch.utils.data import ConcatDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, Trainer
from transformers.trainer_utils import set_seed
from ovis.model.configuration_ovis import OvisConfig
from ovis.model.modeling_ovis import Ovis, VisualTokenizer
from ovis.train.arguments import ModelArguments, TrainingArguments
from ovis.train.callback import MonitorCallback
from ovis.train.dataset.caption_dataset import CaptionDataset
from ovis.train.dataset.conversation_dataset import ConversationDataset
from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset
from ovis.util.constants import BEGIN_LINE, END_LINE
from ovis.util.utils import smart_unit, rank0_print, rankN_print, replace_torch_load_with_weights_only_false
def load_model(model_args: ModelArguments, training_args: TrainingArguments):
model, loading_info = Ovis.from_pretrained(
training_args.ovis_pretrained_path,
output_loading_info=True,
trust_remote_code=True
)
rankN_print(BEGIN_LINE)
rankN_print(f'Loading info of Ovis:\n{loading_info}')
rankN_print(END_LINE)
model.accepts_loss_kwargs = model_args.accepts_loss_kwargs
if model_args.attn_implementation:
model.llm.config._attn_implementation = model_args.attn_implementation
model.visual_tokenizer.vit.config._attn_implementation = model_args.attn_implementation
model.llm.config.use_cache = False
model.config.use_cache = False
rank0_print(BEGIN_LINE)
rank0_print(f'model.config:\n{model.config}')
rank0_print(END_LINE)
return model
def load_data(model: Ovis, training_args: TrainingArguments):
# construct data module
if training_args.data_type == 'caption':
train_dataset = CaptionDataset(model, training_args)
elif training_args.data_type == 'conversation':
train_dataset = ConversationDataset(model, training_args)
else:
raise ValueError(f'Invalid data type: {training_args.data_type}')
data_module = dict(
train_dataset=train_dataset,
data_collator=DataCollatorForMultimodalDataset(model.text_tokenizer, training_args)
)
return data_module
def train(model_args: ModelArguments, training_args: TrainingArguments):
# save args to checkpoint dir
with training_args.main_process_first(local=False):
if training_args.process_index == 0:
def args2dict(args):
return {k: str(v) for k, v in args.__dict__.items()}
args_log = json.dumps(dict(
model_args=args2dict(model_args),
training_args=args2dict(training_args)
), ensure_ascii=False, indent=2)
print(args_log)
os.makedirs(training_args.output_dir, exist_ok=True)
with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
encoding='utf-8') as f:
f.write(args_log + '\n')
# load model & data
model = load_model(model_args, training_args)
# select train modules, support different learning rate for different modules
model.requires_grad_(False)
parameters = []
for module_name_lr in training_args.train_modules.split('|'):
module_name_lr = module_name_lr.replace(' ', '').split(':')
module_lr = training_args.learning_rate
if len(module_name_lr) == 2:
module_name, module_lr = module_name_lr[0], float(module_name_lr[1])
elif len(module_name_lr) == 1:
module_name = module_name_lr[0]
else:
raise ValueError
match module_name:
case 'all':
module = model
case 'llm':
module = model.llm
case 'visual_tokenizer':
module = model.visual_tokenizer
case 'visual_tokenizer.head':
module = model.visual_tokenizer.head
case 'visual_tokenizer.vit':
module = model.visual_tokenizer.vit
case 'visual_tokenizer.vit.last_block':
module = model.visual_tokenizer._get_last_block()
case 'vte':
module = model.vte
case _:
raise ValueError(f'Invalid train module name: {module_name}')
module.requires_grad_(True)
parameters.append({'params': module.parameters(), 'lr': module_lr})
optimizer = torch.optim.AdamW(parameters, lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
rank0_print(BEGIN_LINE)
rank0_print('Parameters to train:')
param_lr_mapping = {}
for group in optimizer.param_groups:
lr = group['lr']
for param in group['params']:
            param_lr_mapping[param] = lr
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(f'{name}: lr={param_lr_mapping.get(param, training_args.learning_rate)}')
rank0_print(f'LLM\'s attn implementation: {model.llm.config._attn_implementation}')
rank0_print(
f'ViT\'s attn implementation: {model.visual_tokenizer.vit.config._attn_implementation}'
)
rank0_print(END_LINE)
# construct data module
datasets = []
dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
f'dataset/{training_args.data_info_version}.json')
with open(dataset_info_path, 'r', encoding='utf-8') as f:
dataset_info = json.load(f)
for name in training_args.data_name.split('|'):
info = dataset_info[name]
data_format = info['data_format']
if data_format == 'caption':
dataset = CaptionDataset(name, info, model, training_args)
elif data_format == 'conversation':
dataset = ConversationDataset(name, info, model, training_args)
else:
raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
datasets.append(dataset)
data_module = dict(
train_dataset=ConcatDataset(datasets),
data_collator=DataCollatorForMultimodalDataset(model.text_tokenizer)
)
# train
    trainer = Trainer(
        model=model,
        args=training_args,
        optimizers=(optimizer, None),
        callbacks=[MonitorCallback],
        **data_module
    )
rankN_print(BEGIN_LINE)
rankN_print(f'model_accepts_loss_kwargs: {trainer.model_accepts_loss_kwargs}')
rankN_print(END_LINE)
rankN_print(BEGIN_LINE)
rankN_print('Dataset sample tensor:')
rankN_print(data_module['train_dataset'][0])
rankN_print(END_LINE)
rankN_print(BEGIN_LINE)
rankN_print('Dataset sample input_ids decoding:')
rankN_print(model.text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
rankN_print(END_LINE)
rankN_print(BEGIN_LINE)
rankN_print('Dataset sample labels decoding:')
rankN_print(model.text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
rankN_print(END_LINE)
rankN_print(BEGIN_LINE)
rankN_print(f'#param of model: {smart_unit(model.num_parameters())}')
rankN_print(f'#param of llm: {smart_unit(model.llm.num_parameters())}')
rankN_print(f'#param of vit: {smart_unit(model.visual_tokenizer.vit.num_parameters())}')
rankN_print(f'#param of vte: {smart_unit(model.vte.weight.numel())}')
rankN_print(f'#dtype of model: {model.dtype}')
rankN_print(f'#dtype of llm: {model.llm.dtype}')
rankN_print(f'#dtype of vit: {model.visual_tokenizer.vit.dtype}')
rankN_print(f'#dtype of vte: {model.vte.weight.dtype}')
rankN_print(END_LINE)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
# save model
model.llm.config.use_cache = True
model.config.use_cache = True
trainer.save_model()
if __name__ == '__main__':
replace_torch_load_with_weights_only_false()
parser = transformers.HfArgumentParser(
(ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()
train(model_args, training_args)
# Model Constants
IGNORE_ID = -100
IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_ID = -200
VIDEO_TOKEN = "<video>"
VIDEO_TOKEN_ID = -201
VISUAL_ATOM_ID = -300
INDICATOR_IDS = [-301, -302, -303, -304]
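# The negative IDs never collide with real vocabulary tokens: they mark visual atoms and
# multimodal boundaries inside input_ids and are filtered out (x >= 0) before text decoding.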
# Log & Print
BEGIN_LINE = '========================************========================'
END_LINE = '------------------------------------------------------------'
import os
from importlib import import_module
import torch
def rank0_print(*args):
if int(os.getenv("LOCAL_PROCESS_RANK", os.getenv("LOCAL_RANK", 0))) == 0:
print(*args)
def rankN_print(*args):
rank = int(os.getenv("LOCAL_PROCESS_RANK", os.getenv("LOCAL_RANK", 0)))
print(f'<R{rank}>', *args)
def smart_unit(num):
if num / 1.0e9 >= 1:
return f'{num / 1.0e9:.2f}B'
else:
return f'{num / 1.0e6:.2f}M'
def replace_torch_load_with_weights_only_false():
original_torch_load = torch.load
def torch_load_with_weights_only_false(*args, **kwargs):
kwargs["weights_only"] = False
return original_torch_load(*args, **kwargs)
    # Replace torch.load so the patched version always passes weights_only=False
torch.load = torch_load_with_weights_only_false
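# Note: PyTorch 2.6+ flips the torch.load default to weights_only=True, which rejects checkpoints
# containing pickled non-tensor objects; this patch restores the permissive behaviour globally.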
import torch
import requests
from PIL import Image
from transformers import AutoModelForCausalLM
MODEL_PATH = "AIDC-AI/Ovis2.5-2B"
# Thinking mode & budget
enable_thinking = True
enable_thinking_budget = True # Only effective if enable_thinking is True.
# Total tokens for thinking + answer. Ensure: max_new_tokens > thinking_budget + 25
max_new_tokens = 3072
thinking_budget = 2048
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
messages = [{
"role": "user",
"content": [
{"type": "image", "image": Image.open("./doc/TIlymOb86R6_Mez3bpmcB.png")},
{"type": "text", "text": "Calculate the sum of the numbers in the middle box in figure (c)."},
],
}]
input_ids, pixel_values, grid_thws = model.preprocess_inputs(
messages=messages,
add_generation_prompt=True,
enable_thinking=enable_thinking
)
input_ids = input_ids.cuda()
pixel_values = pixel_values.cuda() if pixel_values is not None else None
grid_thws = grid_thws.cuda() if grid_thws is not None else None
outputs = model.generate(
inputs=input_ids,
pixel_values=pixel_values,
grid_thws=grid_thws,
enable_thinking=enable_thinking,
enable_thinking_budget=enable_thinking_budget,
max_new_tokens=max_new_tokens,
thinking_budget=thinking_budget,
)
response = model.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
# MDP3: A Training-free Approach for List-wise Frame Selection in Video-LLMs
# https://arxiv.org/abs/2501.02885
import copy
import time
from contextlib import contextmanager
import torch
from torch import nn
from transformers import AutoProcessor, AutoModel
@contextmanager
def timer(hint=""):
start = time.perf_counter()
yield
end = time.perf_counter()
print(f"{hint} runtime: {end - start:.4f} s")
INF = 0x7fffffff
class VisualEncoder():
def __init__(self, model_path, device="cuda"):
self.device = device
self.model_path = model_path
self.model = AutoModel.from_pretrained(self.model_path)
self.model = self.model.to(device=self.device)
self.processor = AutoProcessor.from_pretrained(self.model_path)
def __call__(self, images, texts, clear_prompt=False):
if clear_prompt:
texts = self.clear_prompt(copy.deepcopy(texts))
with timer("visual processor"):
inputs = self.processor(
text=texts, images=images, padding="max_length", return_tensors="pt").to(self.model.device)
stride_num = (int(inputs["input_ids"].shape[-1]) + 63) // 64
stride = (inputs["input_ids"].shape[-1] + stride_num - 1) // stride_num
input_id_heads, input_id_tails = [], []
l, r = 0, inputs["input_ids"].shape[-1]
while l < r:
input_id_heads.append(inputs["input_ids"][:, l:l + stride])
l += stride
if l < r:
input_id_tails.append(inputs["input_ids"][:, r - stride:r])
r -= stride
input_ids = input_id_heads + input_id_tails[::-1]
input_ids = torch.cat(input_ids)
with timer("extract embeds"):
with torch.no_grad():
with torch.autocast(self.device):
outputs = self.model(
input_ids, pixel_values=inputs["pixel_values"])
image_embeds = outputs.image_embeds
text_embeds = outputs.text_embeds
return image_embeds, text_embeds.mean(dim=0, keepdim=True)
def clear_prompt(self, prompt):
heads = [
"Select the best answer to the following multiple-choice question based on the video and the subtitles. Respond with only the letter (A, B, C, or D) of the correct option.",
"Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option.",
"Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.",
"Carefully watch this video and pay attention to every detail. Based on your observations, select the best option that accurately addresses the question."
]
tails = [
"Answer with the option's letter from the given choices directly.",
"The best answer is:",
"Answer the question using a single word or phrase.",
"Only give the best option.",
"Best option: (",
"Please directly give the best option:"
]
for head in heads:
prompt = prompt.split(head)[-1]
for tail in tails:
prompt = prompt.split(tail)[0]
prompt = prompt.strip()
return prompt
class MDP3:
def __init__(self, n_selection=16, visual_encoder_name_or_path="google/siglip-so400m-patch14-384", device="cuda"):
super().__init__()
self.n_selection = n_selection
self.lamda = 0.2
self.segment_size = -1
self.condition_size = 1
self.kernel = MultiGaussianKernel(
alphas=[2 ** k for k in list(range(-3, 2))])
self.ve = VisualEncoder(model_path=visual_encoder_name_or_path, device=device)
def __call__(self, conversations, frames, clear_prompt=True):
if len(frames) <= self.n_selection:
return conversations, frames
prompt = "\n".join([x["value"] for x in conversations if x["from"] == "human"])
with timer("** VLMs Process & Extract"):
frame_embeds, text_embeds = self.ve(frames, prompt, clear_prompt)
with timer("Select Frames"):
with torch.no_grad():
selected_idx = self._select_frames_fast(frame_embeds, text_embeds)
        # rebuild the conversations, keeping <image> tokens only for the selected frames
ret_conversions = []
img_cnt = 0
for conv in conversations:
prompt_clips = conv["value"].split("<image>")
conv_v = []
for idx, clip in enumerate(prompt_clips):
if clip != "\n":
conv_v.append(clip)
if img_cnt in selected_idx and idx < len(prompt_clips) - 1:
if clip == "\n":
conv_v.append(clip)
conv_v.append("<image>")
img_cnt += 1
ret_conversions.append({"from": conv["from"], "value": ''.join(conv_v)})
ret_frames = [frames[idx] for idx in selected_idx]
return ret_conversions, ret_frames
def cal_obj(self, selected_images_embeds, text_embed):
kernel_matrix = self.kernel(
torch.cat([text_embed, selected_images_embeds]))
r, S_matrix = kernel_matrix[0:1, 1:], kernel_matrix[1:, 1:]
ret_score = (1. / self.lamda * 2 * torch.log(r).sum()) + \
torch.linalg.slogdet(S_matrix).logabsdet
return ret_score
def _select_frames(self, image_embeds, text_embeds):
        # initializing dynamic programming
N_image = len(image_embeds)
if self.segment_size <= 0:
segment_size = N_image
else:
segment_size = self.segment_size
segment_num = (N_image + segment_size - 1) // segment_size
dp = [[0.] + [-INF] * self.n_selection for _ in range(segment_num + 1)]
trace = [[[] for _ in range(self.n_selection + 1)]
for _ in range(segment_num + 1)]
for seg_idx in range(1, segment_num + 1):
for selected_num in range(1, min(self.n_selection, seg_idx * segment_size) + 1):
for to_select_num in range(0, min(selected_num, segment_size) + 1):
cur_score, cur_trace = self.seqdpp_select(
text_embeds=text_embeds,
image_embeds=image_embeds,
conditional_index=trace[seg_idx - 1][selected_num - to_select_num][
-min(self.condition_size,
len(trace[seg_idx - 1][selected_num - to_select_num])):],
candidate_index=range(
(seg_idx - 1) * segment_size, seg_idx * segment_size),
to_select_num=to_select_num
)
cur_score = dp[seg_idx - 1][selected_num -
to_select_num] + cur_score
cur_trace = trace[
seg_idx - 1][selected_num - to_select_num] + cur_trace
if cur_score > dp[seg_idx][selected_num]:
dp[seg_idx][selected_num] = cur_score
trace[seg_idx][selected_num] = cur_trace
return trace[segment_num][self.n_selection]
def _select_frames_fast(self, image_embeds, text_embeds):
        # initializing dynamic programming
N_image = len(image_embeds)
if self.segment_size <= 0:
segment_size = N_image
else:
segment_size = self.segment_size
segment_num = (N_image + segment_size - 1) // segment_size
dp = [[0.] + [-INF] * self.n_selection for _ in range(segment_num + 1)]
trace = [[[] for _ in range(self.n_selection + 1)]
for _ in range(segment_num + 1)]
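        # dp[s][k] is the best accumulated objective after the first s segments with k frames
        # selected in total; trace[s][k] holds the corresponding selected frame indices.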
for seg_idx in range(1, segment_num + 1):
candidate_index = range(
(seg_idx - 1) * segment_size, seg_idx * segment_size)
candidate_embeds = [image_embeds[i] for i in candidate_index]
sim_matrix = self.kernel(torch.stack(candidate_embeds))
for start_selected_num in range(0, min(self.n_selection, (seg_idx - 1) * segment_size) + 1):
conditional_index = trace[seg_idx - 1][start_selected_num][
-min(self.condition_size, len(trace[seg_idx - 1][start_selected_num])):]
offset = len(conditional_index)
additional_embeds = [text_embeds[
0].reshape(-1)] + [image_embeds[i] for i in conditional_index]
additional = self.kernel(
torch.stack(additional_embeds),
torch.stack(additional_embeds + candidate_embeds)
)
total_matrix = torch.cat([
additional, # [add, 32+add]
torch.cat([
additional[:, -len(sim_matrix):].T, # [32, add]
sim_matrix # [32, 32]
], dim=1) # [32, add + 32]
], dim=0) # [add+32, add+32]
max_selection = min(self.n_selection -
start_selected_num, segment_size)
cur_scores, cur_traces = self.seqdpp_select_super_fast(
total_matrix, offset, max_selection)
for to_select_num, (cur_score, cur_trace) in enumerate(zip(cur_scores, cur_traces)):
cur_trace = [i + int((seg_idx - 1) * segment_size)
for i in cur_trace]
cur_score = dp[seg_idx - 1][start_selected_num] + cur_score
cur_trace = trace[
seg_idx - 1][start_selected_num] + cur_trace
if cur_score > dp[seg_idx][start_selected_num + to_select_num]:
dp[seg_idx][start_selected_num + to_select_num] = cur_score
trace[seg_idx][start_selected_num +
to_select_num] = cur_trace
return trace[segment_num][self.n_selection]
def seqdpp_select(self, text_embeds, image_embeds, conditional_index, candidate_index, to_select_num):
if to_select_num == 0:
return 0.0, []
conditional_embeds = [image_embeds[i] for i in conditional_index]
cur_trace = []
U_matrix = self.kernel(torch.stack(
conditional_embeds + [image_embeds[i] for i in candidate_index]))
I = torch.diag(
torch.tensor([0.] * len(conditional_index) + [1.] *
len(candidate_index), device=U_matrix.device)
)
obj_values = -torch.linalg.slogdet(U_matrix + I).logabsdet
while len(cur_trace) < to_select_num:
max_obj_gain = -INF
cur_selected_idx = -1
for i in candidate_index:
if i in cur_trace:
continue
cur_obj = self.cal_obj(
selected_images_embeds=torch.stack(
conditional_embeds + [image_embeds[j] for j in cur_trace + [i]]),
text_embed=text_embeds[0].reshape(1, -1)
)
cur_obj_gain = cur_obj - obj_values
if cur_obj_gain > max_obj_gain:
max_obj_gain = cur_obj_gain
cur_selected_idx = i
cur_trace.append(cur_selected_idx)
obj_values += max_obj_gain
cur_trace = sorted(cur_trace)
return obj_values if len(cur_trace) > 0 else 0.0, cur_trace
def seqdpp_select_fast(self, total_matrix, offset, to_select_num):
if to_select_num == 0:
return 0.0, []
cur_trace = []
obj_values = 0.0
r, S_matrix = total_matrix[0:1, 1:], total_matrix[1:, 1:]
candidate_index = range(len(S_matrix) - offset)
while len(cur_trace) < to_select_num:
max_obj_gain = -INF
cur_selected_idx = -1
for i in candidate_index:
if i in cur_trace:
continue
selected_idx = list(range(offset)) + \
[j + offset for j in cur_trace + [i]]
cur_S_matrix = S_matrix[selected_idx][:, selected_idx]
cur_obj = (1. / self.lamda * 2 * torch.log(
r[:, selected_idx]).sum()) + torch.linalg.slogdet(cur_S_matrix).logabsdet
cur_obj_gain = cur_obj - obj_values
if cur_obj_gain > max_obj_gain:
max_obj_gain = cur_obj_gain
cur_selected_idx = i
cur_trace.append(cur_selected_idx)
obj_values += max_obj_gain
cur_trace = sorted(cur_trace)
return obj_values if len(cur_trace) > 0 else 0.0, cur_trace
def seqdpp_select_super_fast(self, total_matrix, offset, to_select_num):
if to_select_num == 0:
return [0.0], [[]]
cur_trace = []
ret_scores = [0.0]
r, S_matrix = total_matrix[0:1, 1:], total_matrix[1:, 1:]
candidate_index = list(range(len(S_matrix) - offset))
conditional_idx = list(range(offset))
L = None
if len(conditional_idx) > 0:
L = torch.linalg.cholesky(
S_matrix[conditional_idx][:, conditional_idx])
while len(cur_trace) < to_select_num:
max_obj = -INF
cur_selected_idx = -1
better_L = None
for i in candidate_index:
if i in cur_trace:
continue
cur_idx = i + offset
selected_idx = conditional_idx + \
[j + offset for j in cur_trace] + [cur_idx]
if L is None:
cur_sim_v = S_matrix[selected_idx][:, selected_idx]
cur_L = torch.sqrt(cur_sim_v).reshape(1, 1)
logdet = cur_sim_v.clone().log()
else:
cur_sim_v = S_matrix[cur_idx:cur_idx + 1][:, selected_idx]
cur_L, logdet = self.cholesky_update_determinant(
L, cur_sim_v)
cur_obj = 1. / self.lamda * 2 * \
torch.log(r[:, selected_idx]).sum() + logdet
if cur_obj > max_obj or cur_selected_idx == -1:
max_obj = cur_obj
cur_selected_idx = i
better_L = cur_L
ret_scores.append(max_obj.clone())
cur_trace.append(cur_selected_idx)
L = better_L
ret_traces = [sorted(cur_trace[:j]) for j in range(len(cur_trace) + 1)]
return ret_scores, ret_traces
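    # Rank-one Cholesky extension: if L satisfies L @ L.T == K for the currently selected kernel
    # block K, appending a candidate column v gives w = solve(L, v[:n]) and a new diagonal entry
    # sqrt(v_nn - w.T @ w); log|det| of the extended block is 2 * sum(log(diag(new_L))), so the
    # greedy loop above can score each candidate without recomputing a full slogdet.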
def cholesky_update_determinant(self, L, v):
n = L.shape[0]
v = v.view(-1, 1)
v_projected = torch.linalg.solve_triangular(L, v[:n], upper=False)
new_diag_element = torch.sqrt(torch.abs(v[-1] - v_projected.T @ v_projected))
new_row = torch.cat((v_projected.flatten(), new_diag_element.view(1)))
new_L = torch.zeros((n + 1, n + 1), dtype=L.dtype, device=L.device)
new_L[:n, :n] = L
new_L[n, :n] = new_row[:-1]
new_L[n, n] = new_diag_element
new_diag = torch.diag(new_L)
new_logdet = 2 * torch.log(new_diag).sum()
return new_L, new_logdet
class GaussianKernel(nn.Module):
def __init__(self, alpha=1.):
super(GaussianKernel, self).__init__()
self.alpha = alpha
def forward(self, X: torch.Tensor) -> torch.Tensor:
l2_distance_square = ((X.unsqueeze(1) - X.unsqueeze(0)) ** 2).sum(2)
return torch.exp(-l2_distance_square / (2 * self.alpha))
class MultiGaussianKernel(nn.Module):
def __init__(self, alphas=None):
super(MultiGaussianKernel, self).__init__()
if alphas is None:
alphas = [2 ** k for k in list(range(-3, 2))]
self.alphas = alphas
    def forward(self, X: torch.Tensor, Y: torch.Tensor = None) -> torch.Tensor:
Y = X.unsqueeze(0) if Y is None else Y.unsqueeze(0)
X = X.unsqueeze(1)
l2_distance_square = ((X - Y) ** 2).sum(2)
return sum([torch.exp(-l2_distance_square / (2 * alpha)) for alpha in self.alphas])
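# Usage sketch (hypothetical frames/conversation; the default encoder path matches the constructor):
#
#   selector = MDP3(n_selection=16, visual_encoder_name_or_path="google/siglip-so400m-patch14-384")
#   conversations = [{"from": "human", "value": "<image>\n" * len(frames) + "What happens in the video?"}]
#   conversations, frames = selector(conversations, frames)  # frames: list of PIL.Image, 16 are kept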
torch==2.4.0
transformers==4.51.3
tokenizers==0.21.1
sentencepiece==0.1.99
pyarrow==18.0.0
accelerate==1.1.0
pydantic_core==2.20.1
markdown2[all]
numpy==1.24.3
scikit-learn==1.2.2
requests
httpx
uvicorn
fastapi==0.112.4
einops==0.6.1
einops-exts==0.0.4
timm==1.0.11
tiktoken
transformers_stream_generator==0.0.4
scipy
pandas
torchaudio
xformers
pillow==10.3.0
deepspeed==0.15.4
pysubs2==1.7.2
moviepy==1.0.3
gradio
#!/bin/bash
set -e
# Experiment name is taken from script filename
EXPNAME="${0##*/}"
EXPNAME="${EXPNAME%.sh}"
OVIS_CKPT_DIR="AIDC-AI/Ovis2.5-9B"
data_name="geometry3k_local"
CMDARG="--deepspeed scripts/zero_configs/zero1_cp.json \
--stage 3 \
--data_info_version ovis2_5_sft_datainfo \
--data_name ${data_name} \
--data_type conversation \
--data_seed 5171 \
--accepts_loss_kwargs True \
--ovis_pretrained_path ${OVIS_CKPT_DIR} \
--attn_implementation flash_attention_2 \
--single_image_min_pixels 200704 \
--single_image_max_pixels 3211264 \
--min_frames 10 \
--max_frames 10 \
--train_modules all \
--multimodal_max_length 6000 \
--text_max_length 6000 \
--output_dir path/to/checkpoints/$EXPNAME \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 16 \
--num_train_epochs 1 \
--eval_strategy no \
--save_strategy steps \
--save_steps 0.4 \
--save_total_limit 10 \
--learning_rate 2e-5 \
--max_grad_norm 1.0 \
--weight_decay 0. \
--warmup_ratio 0.1 \
--lr_scheduler_type cosine \
--logging_steps 1 \
--tf32 True \
--bf16 True \
--dataloader_num_workers 8 \
--dataloader_drop_last True \
--dataloader_persistent_workers True \
--gradient_checkpointing True \
--report_to none \
--run_name $EXPNAME"
echo "Training arguments:"
echo "$CMDARG"
# Run with torchrun
torchrun --nproc_per_node=2 ovis/train/train.py $CMDARG
{
"compile": {
"enabled": true,
"backend": "inductor"
},
"fp16": {
"enabled": false
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 0
}
}
{
"compile": {
"enabled": true,
"backend": "inductor"
},
"fp16": {
"enabled": false
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 1
}
}
{
"compile": {
"enabled": true,
"backend": "inductor"
},
"fp16": {
"enabled": false
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
}
}
{
"compile": {
"enabled": true,
"backend": "inductor"
},
"fp16": {
"enabled": false
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
from setuptools import setup, find_packages
import os
here = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(here, 'requirements.txt'), 'r', encoding='utf-8') as f:
requirements = f.read().splitlines()
setup(
name='ovis',
version='2.5.0',
packages=find_packages(where='.', exclude=('tests', 'docs')),
install_requires=requirements
)