Commit 0063a668 authored by chenzk

v1.0
import os
import json
import cv2
import csv
import io
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torchvision
from cotracker.utils.visualizer import Visualizer
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
device = 'cuda'
grid_size = 15
# cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(device)
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline", source='local').to(device)
visual_trace_folder = "./tools/som_tom/videos"
vis = Visualizer(save_dir=visual_trace_folder, pad_value=0, linewidth=3, tracks_leave_trace=-1)
trace = visual_trace(linewidth=3)
def som_tom(video, pred_tracks, pred_visibility, item={}, epsilon=2):
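# Overview of the steps implemented below: filter traces by visibility, optionally remove
# global camera motion with per-frame homographies, split traces into moving (positive) and
# static (negative) sets by trajectory length, cluster each set to keep a few representative
# traces, and finally overlay Set-of-Mark prompts on the first frame.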
# only keep points that are visible in at least half of the frames
valid_idx = pred_visibility[0].sum(0) > 0.5*pred_tracks.shape[1]
pred_tracks = pred_tracks[:, :, valid_idx]
pred_visibility = pred_visibility[:, :, valid_idx]
# Alg2 L2-4: Remove camera motion
# calculate the trajectory length for pred_tracks
trace_lengths = trace.visual_trace_length(pred_tracks, pred_visibility, (1, 1)).squeeze(0)
# if more than 80% of the trace lengths exceed 0.5, treat the video as having camera motion
camera_motion = (trace_lengths > 0.5).sum() > 0.8*trace_lengths.size(0)
start_pos = pred_tracks[:, 0][0]
reference_pts_np = start_pos.cpu().numpy().reshape(-1, 2)
if camera_motion:
# remove camera motion using homography transformation
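# For each later frame, fit a homography from that frame's points back to the first-frame
# points with RANSAC, then warp the points into the first frame's coordinates; after this
# warp the static background roughly stays in place, so the remaining displacement mostly
# reflects object motion rather than camera motion.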
try:
future_pts_transformed = []
for k in range(1, pred_tracks.shape[1]):
future_pts = pred_tracks[:, k][0]
future_pts_np = future_pts.cpu().numpy().reshape(-1, 2)
try:
(H, status) = cv2.findHomography(future_pts_np, reference_pts_np, cv2.RANSAC, 4.0)
except Exception as e:
continue
future_pts_np_transformed = cv2.perspectiveTransform(future_pts_np.reshape(1, -1, 2), H).reshape(-1, 2)
future_pts_transformed_k = torch.tensor(future_pts_np_transformed, dtype=torch.float32)
future_pts_transformed.append(future_pts_transformed_k)
pred_tracks = torch.stack([start_pos] + future_pts_transformed, dim=0).unsqueeze(0)
except Exception as e:
pass
# Alg2 L5: Find the positive traces and negative traces
pos_tracks = pred_tracks[:, :, trace_lengths > epsilon]
pos_visibility = pred_visibility[:, :, trace_lengths > epsilon]
neg_tracks = pred_tracks[:, :, trace_lengths <= epsilon]
neg_visibility = pred_visibility[:, :, trace_lengths <= epsilon]
# Alg2 L6-7: Clustering for positive and negative traces
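# k-means over the traces keeps only a handful of representative moving (positive) and
# static (negative) traces, so the first-frame overlay carries a small number of marks.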
num_clusters_pos = torch.randint(2, 6, (1,)).item()
pos_sampled_ids = trace.cluster_traces_kmeans(pos_tracks, n_clusters=num_clusters_pos, positive=True)
pos_tracks = pos_tracks[:, :, pos_sampled_ids.bool()]
pos_visibility = pos_visibility[:, :, pos_sampled_ids.bool()]
# clustering for negative traces
num_clusters_neg = torch.randint(6, 15, (1,)).item()
neg_sampled_ids = trace.cluster_traces_kmeans(neg_tracks, n_clusters=num_clusters_neg)
neg_tracks = neg_tracks[:, :, neg_sampled_ids.bool()]
image = video[0][0].numpy().transpose(1, 2, 0).astype(np.uint8)
image = Image.fromarray(image).convert("RGB")
# Alg2 L8: Apply som on the first frame
image, pos_traces_to_mark, neg_traces_to_mark, pos_mark_ids, neg_mark_ids, all_idx = \
som_prompting(image, pos_tracks, neg_tracks, draw_som_positive=True, draw_som_negative=True)
# visualize the traces
images = [image] * pos_tracks.shape[1]
video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float()*255
_ = vis.visualize(video, pos_tracks, pos_visibility, filename=f"som_tom")
video_path = "assets/videos/tom_orig_sample.mp4"
# load the video, starting from frame 20
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_POS_FRAMES, 20)
# get number of frames in cap
num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
images = []
while True:
ret, frame = cap.read()
# stop once there are no more frames to read
if not ret:
break
# convert to RGB
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
images.append(frame)
cap.release()
images = [Image.fromarray(img) for img in images]
# resize images to height=512
images = [img.resize((int(img.width * 512 / img.height), 512)) for img in images]
video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float()*255
video = video.to(device)
# Alg2 L1: Extract visual trace
pred_tracks, pred_visibility = cotracker(video, grid_size=grid_size) # B T N 2, B T N 1
_ = vis.visualize(
video.cpu(),
pred_tracks,
pred_visibility,
query_frame=0,
filename='orig_trace',
)
som_tom(video.cpu(), pred_tracks, pred_visibility)
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import torch
import deepspeed
import glob
import transformers
import tokenizers
import random
import re
from magma.image_processing_magma import MagmaImageProcessor
from magma.processing_magma import MagmaProcessor
from magma.modeling_magma import MagmaForCausalLM
from magma.configuration_magma import MagmaConfig
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
BitsAndBytesConfig,
Trainer,
TrainingArguments,
)
from transformers import AutoTokenizer, AutoConfig
from transformers.trainer import get_model_param_count
from trainer import MagmaTrainer
from data import *
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
from packaging import version
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="microsoft/Magma-8B")
version: Optional[str] = field(default="magma_instruct")
freeze_backbone: bool = field(default=False)
tune_mm_mlp_adapter: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
vision_tower_ckpt: Optional[str] = field(default=None)
img_anyres_strategy: Optional[str] = field(default='crop')
proj_vis_to_txt_tokens: bool = field(default=False)
img_size: Optional[int] = field(default=640)
vision_backbone: Optional[str] = field(default="convnextlarge")
tune_vision_tokenizer: Optional[str] = field(default='none')
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
mm_projector_type: Optional[str] = field(default='linear')
mm_use_trace_start_end: bool = field(default=False)
mm_use_trace_speed: bool = field(default=False)
mm_use_image_start_end: bool = field(default=False)
mm_use_image_history: bool = field(default=False)
mm_use_som_tom: bool = field(default=True)
mm_use_som_tom_orig_img: bool = field(default=False)
spatial_quant_size: Optional[int] = field(default=256)
remove_static_trace_pts: bool = field(default=False)
mm_use_im_patch_token: bool = field(default=True)
mm_vision_select_feature: Optional[str] = field(default="patch")
flash_attn_2_enabled: bool = False
task: Optional[str] = field(default="agent")
@dataclass
class DataArguments:
data_path: str = field(default=None,
metadata={"help": "Path to the training data."})
lazy_preprocess: bool = False
is_multimodal: bool = False
data_format: str = "llava"
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = 'square'
max_num_crops: int = 25
add_im_loss: bool = False
training_size: str = 'default'
show_trace: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
freeze_mm_mlp_adapter: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=512,
metadata={
"help":
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
)
bits: int = field(
default=16,
metadata={"help": "How many bits to use."}
)
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
min_lr_rate: Optional[float] = None
mm_projector_lr: Optional[float] = None
vision_tokenizer_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
local_run: bool = False
max_grad_norm: float = 1.0
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if k in lora_bias_names:
to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
return list(lora_module_names)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
output_dir: str):
"""Collects the state dict and dump to disk."""
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
# Only save Adapter
keys_to_match = ['mm_projector']
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(['embed_tokens', 'embed_in'])
if getattr(trainer.args, "tune_vision_tokenizer", 'none') == "posembed":
keys_to_match.extend(['posembed'])
elif getattr(trainer.args, "tune_vision_tokenizer", 'none') == "decoder":
keys_to_match.extend(['sem_seg_head.predictor'])
elif getattr(trainer.args, "tune_vision_tokenizer", 'none') == "all":
keys_to_match.extend(['vision_tower'])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split('/')[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith('checkpoint-'):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
else:
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
return
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {
key: value.cpu()
for key, value in state_dict.items()
}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
if isinstance(special_tokens_dict, list):
num_new_tokens = tokenizer.add_tokens(special_tokens_dict, special_tokens=True)
else:
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
new_vocab_size = len(tokenizer)
# Update base model and current model config
if hasattr(model.config, "text_config"):
model.config.text_config.vocab_size = new_vocab_size
else:
model.config.vocab_size = new_vocab_size
model.vocab_size = new_vocab_size
if num_new_tokens > 0:
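# Initialize the new embedding rows to the mean of the pre-existing input/output
# embeddings, which tends to be a safer starting point than random initialization.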
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def make_supervised_data_module(processor: MagmaProcessor,
data_args,
training_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = build_joint_dataset(
processor=processor,
data_path=data_args.data_path,
data_args=data_args
)
if training_args.evaluation_strategy != 'no':
val_dataset = build_joint_dataset(
processor=processor,
data_path=data_args.data_path,
data_args=data_args,
is_eval=True
)
else:
val_dataset = None
data_collator = DataCollatorForSupervisedDataset(processor=processor)
return dict(train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=data_collator)
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
if training_args.min_lr_rate is not None:
training_args.lr_scheduler_kwargs = {'min_lr_rate': training_args.min_lr_rate}
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_skip_modules=["mm_projector"],
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
)
))
if 'magma' in model_args.model_name_or_path.lower():
model = MagmaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation="flash_attention_2" if model_args.flash_attn_2_enabled else None,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args
)
magma_processor = MagmaProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True
)
model.config.tokenizer_vocab_size = magma_processor.tokenizer.vocab_size
else:
vision_config = {
"img_size": model_args.img_size,
"anyres_strategy": model_args.img_anyres_strategy,
"vision_backbone": model_args.vision_backbone,
"vision_tower": model_args.vision_tower,
"vision_tower_ckpt": model_args.vision_tower_ckpt,
"mm_vision_select_layer": model_args.mm_vision_select_layer,
"mm_vision_select_feature": model_args.mm_vision_select_feature,
"pretrain_mm_mlp_adapter": model_args.pretrain_mm_mlp_adapter,
"mm_projector_type": model_args.mm_projector_type,
"proj_vis_to_txt_tokens": model_args.proj_vis_to_txt_tokens,
"mm_use_im_patch_token": model_args.mm_use_im_patch_token,
"vision_feature_layer": "clip_vis_dense",
"use_cache": False,
}
text_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True
)
magma_config = MagmaConfig(
vision_config=vision_config,
text_config=text_config,
)
model = MagmaForCausalLM(magma_config)
# reload language model
model.language_model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation="flash_attention_2" if model_args.flash_attn_2_enabled else None,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
trust_remote_code=True,
**bnb_model_from_pretrained_args
)
# reload vision encoder
from open_clip.pretrained import download_pretrained_from_hf
if vision_config['vision_tower'] == 'convnext':
model_id = 'laion/CLIP-convnext_large-laion2B-s34B-b82K-augreg'
else:
model_id = 'laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg'
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=None)
model.load_special_module_from_ckpt(checkpoint_path, torch_dtype=(torch.bfloat16 if training_args.bf16 else None))
# load 'magma/default_preprocessor_config.json' if it exists
if os.path.exists('magma/default_preprocessor_config.json'):
with open('magma/default_preprocessor_config.json') as f:
preprocessor_config = json.load(f)
else:
preprocessor_config = {}
image_processor = MagmaImageProcessor(**preprocessor_config)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
magma_processor = MagmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
smart_tokenizer_and_embedding_resize(
special_tokens_dict=["<image>"],
tokenizer=magma_processor.tokenizer,
model=model,
)
# if tokenizer does not have pad_token, add it
if magma_processor.tokenizer.pad_token_id is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict={'pad_token': '<pad>'},
tokenizer=magma_processor.tokenizer,
model=model,
)
model.config.image_token_index = tokenizer.convert_tokens_to_ids("<image>")
model.config.tokenizer_vocab_size = magma_processor.tokenizer.vocab_size
model = model.to(training_args.device)
magma_processor.tokenizer.model_max_length = training_args.model_max_length
magma_processor.image_processor.base_img_size = model_args.img_size
magma_processor.image_processor.anyres_strategy = model_args.img_anyres_strategy
if model_args.mm_use_trace_start_end:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=["<trace_start>", "<trace_end>"],
tokenizer=magma_processor.tokenizer,
model=model,
)
if model_args.mm_use_image_start_end:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=["<image_start>", "<image_end>"],
tokenizer=magma_processor.tokenizer,
model=model,
)
# we add an <action> token as the place holder for the action
smart_tokenizer_and_embedding_resize(
special_tokens_dict=["<action>"],
tokenizer=magma_processor.tokenizer,
model=model,
)
if model_args.freeze_backbone:
model.requires_grad_(False)
if training_args.bits in [4, 8]:
from peft import prepare_model_for_kbit_training
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
rank0_print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
if model_args.tune_mm_mlp_adapter:
model.requires_grad_(False)
for p in model.multi_modal_projector.parameters():
p.requires_grad = True
if training_args.freeze_mm_mlp_adapter:
for p in model.multi_modal_projector.parameters():
p.requires_grad = False
if model_args.tune_vision_tokenizer == "none":
for name, p in model.vision_tower.named_parameters():
p.requires_grad = False
total_params = get_model_param_count(model, trainable_only=True)
rank0_print(f"Total trainable parameters: {total_params}")
if training_args.bits in [4, 8]:
model.multi_modal_projector.to(dtype=compute_dtype, device=training_args.device)
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if 'norm' in name:
module = module.to(torch.float32)
if 'lm_head' in name or 'embed_tokens' in name:
if hasattr(module, 'weight'):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
data_args.mm_use_trace_start_end = model_args.mm_use_trace_start_end
data_args.mm_use_trace_speed = model_args.mm_use_trace_speed
data_args.mm_use_image_start_end = model_args.mm_use_image_start_end
data_args.mm_use_image_history = model_args.mm_use_image_history
data_args.mm_use_som_tom = model_args.mm_use_som_tom
data_args.mm_use_som_tom_orig_img = model_args.mm_use_som_tom_orig_img
data_args.remove_static_trace_pts = model_args.remove_static_trace_pts
data_args.spatial_quant_size = model_args.spatial_quant_size
data_args.version = model_args.version
data_args.local_run = training_args.local_run
data_args.task = model_args.task
model.config.mm_use_trace_start_end = model_args.mm_use_trace_start_end
model.config.mm_use_trace_speed = model_args.mm_use_trace_speed
model.config.mm_use_image_start_end = model_args.mm_use_image_start_end
model.config.mm_use_image_history = model_args.mm_use_image_history
model.config.remove_static_trace_pts = model_args.remove_static_trace_pts
model.config.mm_use_som_tom = model_args.mm_use_som_tom
model.config.mm_use_som_tom_orig_img = model_args.mm_use_som_tom_orig_img
model.config.spatial_quant_size = model_args.spatial_quant_size
model.config.img_size = model_args.img_size
model.config.use_cache = False
model.config.vision_config['img_anyres_strategy'] = model_args.img_anyres_strategy
data_module = make_supervised_data_module(processor=magma_processor,
data_args=data_args,
training_args=training_args)
trainer = MagmaTrainer(model=model,
tokenizer=magma_processor.tokenizer,
args=training_args,
**data_module)
# print training_args
rank0_print(training_args)
rank0_print(model_args)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
print("Resuming from checkpoint...")
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(
model.named_parameters(), training_args.lora_bias
)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
model.named_parameters()
)
if training_args.local_rank == 0 or training_args.local_rank == -1:
model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
else:
safe_save_model_for_hf_trainer(trainer=trainer,
output_dir=training_args.output_dir)
# save image_processor config for rank 0
if training_args.local_rank == 0 or training_args.local_rank == -1:
magma_processor.image_processor.save_pretrained(training_args.output_dir)
if __name__ == "__main__":
train()
from .trainer import MagmaTrainer
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 1,
"overlap_comm": false,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
import os
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Sampler
from torch.cuda import synchronize
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
logger,
)
from typing import List, Optional
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, 'no ignore status')
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `num_chunks` chunks of roughly equal total length.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
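# Sign convention used here: positive lengths mark multimodal samples, negative lengths mark
# language-only samples (the magnitude is the sequence length in both cases).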
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
# indices = torch.arange(len(lengths))
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
group_by_modality: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
# gather self lengths from all processes
# if self.world_size > 1:
# # gather the size of lengths from all processes
# sizes = torch.tensor([len(lengths)], device=torch.device("cuda"))
# # take minimum length
# torch.distributed.all_reduce(sizes, op=torch.distributed.ReduceOp.MIN)
# min_size = sizes.item()
# # trim lengths to the minimum size
# lengths = lengths[:min_size]
# # append lengths from all processes
# all_lengths = [torch.zeros_like(lengths) for _ in range(self.world_size)]
# torch.distributed.all_gather(all_lengths, torch.tensor(lengths, device=torch.device("cuda")))
# lengths = torch.cat(all_lengths, dim=0).tolist()
# if torch.distributed.get_rank() == 0:
# import pdb; pdb.set_trace()
self.lengths = lengths
self.generator = generator
self.group_by_modality = group_by_modality
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
class MagmaTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_moving_average = None
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
self.args.train_batch_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
lengths=lengths,
group_by_modality=True,
)
else:
return super()._get_train_sampler()
def evaluation_loop(self, dataloader, description: str, prediction_loss_only: Optional[bool] = None, ignore_keys=None, metric_key_prefix: str = "eval"):
"""
Override the `evaluation_loop` method for custom evaluation.
"""
# Custom logic before evaluation loop starts
print(f"Starting custom evaluation loop: {description}")
synchronize() # This ensures all GPU operations are completed
# Initialize containers for predictions and labels
all_preds = []
all_labels = []
# Iterate over the evaluation data loader
for step, inputs in enumerate(dataloader):
# Optionally, apply the data collator manually here if needed
# inputs = self.data_collator(inputs)
# Move batch to the appropriate device
inputs = self._prepare_inputs(inputs)
# Disable gradient calculation during evaluation
with torch.no_grad():
outputs = self.model(**inputs)
# Extract logits (predictions) and labels
logits = outputs.logits
labels = inputs['labels']
# Collect predictions and labels for this batch
preds = logits.argmax(dim=-1) # Assuming classification task
all_preds.append(preds.cpu().numpy())
all_labels.append(labels.cpu().numpy())
# Concatenate all batches to get complete predictions and labels
all_preds = np.concatenate(all_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
# Custom metric computation logic can be applied here
metrics = self.compute_metrics((all_preds, all_labels))
print(f"Finished custom evaluation loop. Metrics: {metrics}")
# Return the results as expected by the trainer (you can customize this)
return metrics
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
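# Build per-parameter groups: bias-like parameters get no weight decay, and the projector /
# vision tower receive their own learning rates when mm_projector_lr / vision_tokenizer_lr are set.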
optimizer_grouped_parameters = []
for n, p in opt_model.named_parameters():
if p.requires_grad:
if "mm_projector" in n and self.args.mm_projector_lr is not None:
optimizer_grouped_parameters.append(
{
"params": [p],
"weight_decay": self.args.weight_decay if n in decay_parameters else 0.0,
"lr": self.args.mm_projector_lr,
}
)
elif "vision_tower" in n and self.args.vision_tokenizer_lr is not None:
optimizer_grouped_parameters.append(
{
"params": [p],
"weight_decay": self.args.weight_decay if n in decay_parameters else 0.0,
"lr": self.args.vision_tokenizer_lr,
}
)
else:
optimizer_grouped_parameters.append(
{
"params": [p],
"weight_decay": self.args.weight_decay if n in decay_parameters else 0.0,
}
)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, 'tune_mm_mlp_adapter', False):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ['mm_projector', 'vision_resampler', 'segtok_']
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(['embed_tokens', 'embed_in'])
if getattr(self.args, "tune_vision_tokenizer", 'none') == "posembed":
keys_to_match.extend(['posembed'])
elif getattr(self.args, "tune_vision_tokenizer", 'none') == "decoder":
keys_to_match.extend(['sem_seg_head.predictor'])
elif getattr(self.args, "tune_vision_tokenizer", 'none') == "all":
keys_to_match.extend(['vision_tower'])
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
print(f"keys to match: {keys_to_match}")
print(f"save checkpoint to {os.path.join(output_dir, f'mm_projector.bin')}")
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
else:
super(MagmaTrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, 'tune_mm_mlp_adapter', False):
pass
else:
super(MagmaTrainer, self)._save(output_dir, state_dict)
import torch
import torchvision
from torch.utils.data import DataLoader
import os
import sys
import argparse
from typing import Dict, Optional, Sequence, List
from dataclasses import dataclass, field
import clip
import multiprocessing as mp
from dataloader import *
import threading
import json
import pickle
parser = argparse.ArgumentParser('')
parser.add_argument('--dataset_name', type=str, default="video-dataset", metavar='DN',
help='dataset name for finding annotation files')
parser.add_argument('--trace_path', type=str, default="/pickle/file/format", metavar='TP',
help='path to a pickle file containing a list of paths to the extracted traces')
parser.add_argument('--clip_score_dir', type=str, default="/path/to/dir/clip_filtered_scores/", metavar='CSD',
help='path to directory which contains the clip scores')
parser.add_argument('--output_path', type=str, default="/path/to/output/file", metavar='OVD',
help='path to output json file containing list of valid traces')
parser.add_argument('--min_score', type=float, default=0.25, metavar='MS',
help='minimum CLIP score required to keep a trace')
parser.add_argument('--split_idx', type=int, default=0, metavar='SI',
help='index for splitting entire dataset over multiple GPUs')
parser.add_argument('--num_samples_per_segment', type=int, default=10400145, metavar='NS',
help='specify number of segments per GPU')
parser.add_argument('--num_workers', type=int, default=8, metavar='NW',
help='number of worker processes')
parser.add_argument('--batch_size', type=int, default=128, metavar='BS',
help='batch size')
parser.add_argument('--thread_num', type=int, default=72, metavar='TN',
help='number of threads')
valid_traces = []
full_set_traces = []
def main():
global args
args = parser.parse_args()
os.makedirs(os.path.dirname(args.output_path) or '.', exist_ok=True)
output_path = args.output_path
all_traces = list(pickle.load(open(args.trace_path, 'rb'))) # This should be a list that contains the paths to all extracted traces
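# Each entry is indexed below as trace[0] (video folder) and trace[-1] (trace file stem),
# so the pickle is assumed to hold per-trace sequences rather than flat path strings.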
print('all_traces: ', len(all_traces))
print('')
lock = threading.Lock()
# Function that writes to the global set
def add_to_set(tid, split_traces):
print('split %s: ' % tid, len(split_traces))
print('')
for idx, trace in enumerate(split_traces):
if tid == 0 and idx % 1000 == 0:
print(idx)
trace_path = trace[0] + '/' + trace[-1]
score_path = os.path.join(args.clip_score_dir, trace[0], '%s.pth' % trace[-1])
try:
trace_score = torch.load(score_path, map_location='cpu').max()
if trace_score >= args.min_score:
global valid_traces
with lock: # Ensure that only one thread can modify the set at a time
valid_traces.append(trace_path)
global full_set_traces
with lock:
full_set_traces.append(trace_path)
except:
continue
# Create threads
per_process_video_num = len(all_traces) // args.thread_num
threads = []
for i in range(args.thread_num):
if i == args.thread_num - 1:
sub_files = all_traces[i * per_process_video_num :]
else:
sub_files = all_traces[i * per_process_video_num : (i + 1) * per_process_video_num]
t = threading.Thread(target=add_to_set, args=(i, sub_files,))
threads.append(t)
t.start()
# Wait for all threads to finish
for t in threads:
t.join()
json.dump(valid_traces, open(output_path, 'w'))
print('valid_traces: ', len(valid_traces))
print('full_set_traces: ', len(full_set_traces))
if __name__ == "__main__":
main()
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import Resize
from torchvision.transforms import ToPILImage
import os
import sys
import argparse
from typing import Dict, Optional, Sequence, List
from dataclasses import dataclass, field
import clip
import concurrent.futures
import pickle
parser = argparse.ArgumentParser('')
parser.add_argument('--dataset_name', type=str, default="video-dataset", metavar='DN',
help='dataset name for finding annotation files')
parser.add_argument('--trace_path', type=str, default="/path/to/dir/for/extracted/traces", metavar='TP',
help='path to directory with extracted traces')
parser.add_argument('--ann_path', type=str, default="/path/to/processed/language/annotations", metavar='AP',
help='path to language annotations. Process this into a dictionary in any format you prefer. We save this as a pickle file.')
parser.add_argument('--input_video_dir', type=str, default="/path/to/video/dir", metavar='IVD',
help='path to input video directory')
parser.add_argument('--output_dir', type=str, default="/path/to/clip_filtered_scores/", metavar='OD',
help='path to store clip filtered scores')
parser.add_argument('--filtered', type=bool, default=True, metavar='FIL',
help='boolean flag to specify if a filtered list is used')
parser.add_argument('--num_frames', type=int, default=4, metavar='NF',
help='number of frames to use per video')
parser.add_argument('--max_segment_frames', type=int, default=16, metavar='MSF',
help='maximum number of frames per segment')
parser.add_argument('--num_workers', type=int, default=8, metavar='NW',
help='number of worker processes')
parser.add_argument('--batch_size', type=int, default=2, metavar='BS',
help='batch size')
@dataclass
class DataCollator(object):
"""Collate examples."""
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
frames = []
text = []
num_frames = []
segment = []
for instance in instances:
frames.append(instance['frames'])
text += instance['text']
num_frames.append(instance['num_frames'])
segment.append(instance['segment'])
frames = torch.stack(frames)
return {'frames': frames , 'text': text, 'num_frames': num_frames, 'segment': segment}
class VideoDataset(Dataset):
def __init__(self, args, img_processor):
self.args = args
self.dataset_name = args.dataset_name
self.video_dir = args.input_video_dir
self.max_segment_frames = args.max_segment_frames
self.num_frames = args.num_frames
self.split_idx = getattr(args, 'split_idx', 0)
self.num_samples_per_segment = getattr(args, 'num_samples_per_segment', None)
self.filtered = args.filtered
self.img_processor = img_processor
self.to_pil = ToPILImage()
self.vid2anns = pickle.load(open(args.ann_path, 'rb')) # a nested map from video to language annotations. For example, the keys at the first level will be the video and those at the second level are segments (denoted by start and end times). The values are the language annotations.
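# A hypothetical example of the expected structure (keys and values are illustrative only):
# vid2anns = {'video_001': {(12.0, 18.5): 'person opens the drawer',
#                           (20.0, 27.0): 'person picks up a cup'}}
# For datasets other than howto100m, the inner keys also carry a narration index, e.g. (12.0, 18.5, 0).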
self.all_traces = pickle.load(open(args.trace_path, 'rb'))
def __len__(self):
return len(self.all_traces)
def __getitem__(self, idx):
curr_segment = self.all_traces[idx]
vid, trace_id = curr_segment
tmp = trace_id.split('_trace_')
clip = tmp[0]
times = tmp[-1].split('_')
start = int(times[0])
end = int(times[1])
intervals = trace_id.split('___')
time_start = float(intervals[0].split('_')[-1])
time_end = float(intervals[1].split('_')[-1])
try:
video_path = os.path.join(self.video_dir, vid, '%s.mp4' % clip)
frames, _, _ = torchvision.io.read_video(video_path)
frames = frames.permute(0, 3, 1, 2)
selected_frames = []
selected_indices = torch.linspace(start, end-1, min(self.num_frames, end-start+1))
for i in selected_indices:
selected_frames.append(self.img_processor(self.to_pil(frames[int(i)])))
selected_frames = torch.stack(selected_frames)
trace_num_frames = len(selected_frames)
if len(selected_frames) < self.num_frames:
last = selected_frames[-1].unsqueeze(0).repeat(self.num_frames-len(selected_frames), 1, 1, 1)
selected_frames = torch.cat((selected_frames, last), dim=0)
time_start, time_end, _ = clip.split('___')
time_start = float(time_start.split('start_')[-1])
time_end = float(time_end.split('end_')[-1])
if self.dataset_name == 'howto100m':
text = self.vid2anns[vid][(time_start, time_end)]
else:
narr = int(clip.split('narr_')[-1])
text = self.vid2anns[vid][(time_start, time_end, narr)]
except:
selected_frames = torch.zeros((self.num_frames, 3, 336, 336))
text = 'none'
trace_num_frames = 0
return {'frames': selected_frames, 'text': [text], 'num_frames': trace_num_frames, 'segment': [curr_segment]}
def save_tensor(tensor, file_path):
torch.save(tensor, file_path)
def main():
global args
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
import clip
model, preprocess = clip.load("ViT-L/14@336px", device='cuda')
model.eval()
dataset = VideoDataset(args, preprocess)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=args.num_workers,
collate_fn=DataCollator()
)
for idx, sample in enumerate(dataloader):
print(idx)
frames = sample['frames']
text = sample['text']
num_frames = sample['num_frames']
segment = sample['segment']
batch_size = len(frames)
frames = frames.view(-1, frames.size(-3), frames.size(-2), frames.size(-1)).to('cuda')
text_tokens = clip.tokenize(text, truncate=True).to('cuda')
text_tokens = text_tokens.unsqueeze(1).repeat(1, args.num_frames, 1)
text_tokens = text_tokens.view(-1, text_tokens.size(-1))
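# Each caption is repeated num_frames times so that, after flattening, frame and text
# features line up one-to-one for the per-frame similarity computed below.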
with torch.no_grad():
frame_features = model.encode_image(frames)
text_features = model.encode_text(text_tokens)
# normalized features
frame_features = frame_features / frame_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
scores = torch.sum(frame_features * text_features, dim=-1)
scores = scores.view(batch_size, -1)
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
futures = []
for sample_idx in range(len(scores)):
curr_num = num_frames[sample_idx]
if curr_num == 0:
continue
curr_scores = scores[sample_idx][:curr_num]
curr_vid, curr_trace = segment[sample_idx][0] # each video (curr_vid) is split into segments that are further split into windows of 16 frames (curr_trace). E.g, Video1, segment_4_frames_16_32
output_vid_path = os.path.join(args.output_dir, curr_vid)
curr_output_path = os.path.join(output_vid_path, '%s.pth' % curr_trace)
if not os.path.exists(output_vid_path):
os.mkdir(output_vid_path)
futures.append(executor.submit(save_tensor, curr_scores.cpu(), curr_output_path))
concurrent.futures.wait(futures)
if __name__ == "__main__":
main()
import torch
import json
import cv2
import os
import sys
import csv
import pickle
import argparse
import random
import numpy as np
import multiprocessing as mp
import imageio
from scenedetect import detect, ContentDetector
parser = argparse.ArgumentParser('')
parser.add_argument('--ann_path', type=str, default="/path/to/json/file", metavar='AP',
help='path to a json file with video and language annotations; see the comment block in main() for the expected format')
parser.add_argument('--video_dir', type=str, default="/path/to/video/directory", metavar='VD',
help='path to video dir')
parser.add_argument('--temp_video_segment_dir', type=str, default="./temp_video_segments", metavar='TD',
help='temporary directory to store split video segments')
parser.add_argument('--output_segment_dir', type=str, default="./detected_video_segments", metavar='OD',
help='path to store the final detected video segments')
parser.add_argument('--target_fps', type=int, default=None, metavar='TFPS',
help='FPS to sample frames')
parser.add_argument('--thread_num', type=int, default=8, metavar='TN',
help='number of threads')
def extract_video_frames_and_metadata(video_path, target_fps=1):
'''
Extracts video frames at 1 fps by default
'''
cap = cv2.VideoCapture(video_path)
vid_fps = cap.get(cv2.CAP_PROP_FPS)
round_vid_fps = round(vid_fps)
num_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
duration = num_frames / round_vid_fps
if target_fps is not None:
hop = round(vid_fps / target_fps)
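# keep roughly one frame out of every `hop`, so the sampled sequence approximates target_fps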
all_frames = []
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if target_fps is not None:
if frame_idx % hop == 0:
all_frames.append(frame)
else:
all_frames.append(frame)
frame_idx += 1
cap.release()
return vid_fps, num_frames, duration, all_frames
def write_video(video, output_path, write_fps):
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
video_writer = imageio.get_writer(output_path, fps=write_fps)
for frame in wide_list[2:-1]:
video_writer.append_data(frame)
video_writer.close()
return
def extract_num_frames(video_path):
'''
Returns the frame count, FPS, and an open cv2.VideoCapture for the video
'''
cap = cv2.VideoCapture(video_path)
vid_fps = cap.get(cv2.CAP_PROP_FPS)
round_vid_fps = round(vid_fps)
num_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
return num_frames, vid_fps, cap
def process_single_vid(vid, vid_anns, min_part_duration=3):
vid_path = os.path.join(args.video_dir, '%s.mp4' % vid)
num_frames, vid_fps, cap = extract_num_frames(vid_path)
vid_fps = int(vid_fps)
start = vid_anns['start']
end = vid_anns['end']
text = vid_anns['text']
for idx, curr_start in enumerate(start):
curr_end = end[idx]
curr_text = text[idx]
full_segment_path = os.path.join(args.temp_video_segment_dir, '%s___start_%s___end_%s.mp4' % (vid, curr_start, curr_end))
actual_start = curr_start * vid_fps
actual_end = curr_end * vid_fps
actual_num_frames = int(actual_end - actual_start + 1)
cap.set(cv2.CAP_PROP_POS_FRAMES, actual_start)
all_frames = []
for frame_idx in range(actual_num_frames):
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
all_frames.append(frame)
all_frames = np.stack(all_frames)
all_frames = torch.from_numpy(all_frames)
all_frames = all_frames.permute(0, 3, 1, 2)[None].byte()
write_video(all_frames, full_segment_path, vid_fps)
scene_list = detect(full_segment_path, ContentDetector())
output_dir = os.path.join(args.output_segment_dir, vid)
if not os.path.exists(output_dir):
os.mkdir(output_dir)
if len(scene_list) == 0:
split_segment_path = os.path.join(output_dir, 'start_%s___end_%s___part_%s.mp4' % (curr_start, curr_end, 0))
if os.path.exists(split_segment_path):
continue
write_video(all_frames, split_segment_path, vid_fps)
else:
for part_idx, scene in enumerate(scene_list):
first = scene[0].get_frames()
second = scene[1].get_frames()
split_segment_path = os.path.join(output_dir, 'start_%s___end_%s___part_%s.mp4' % (curr_start, curr_end, part_idx))
if os.path.exists(split_segment_path):
continue
write_video(all_frames[:, first:second+1], split_segment_path, vid_fps)
cmd = 'rm %s' % full_segment_path
os.system(cmd)
return
def sub_processor(pid, files, data):
print(pid, ' : ', len(files))
for curr_vid in files[:]:
try:
vid_anns = data[curr_vid]
process_single_vid(curr_vid, vid_anns)
except:
continue
def main():
global args
args = parser.parse_args()
if not os.path.exists(args.output_segment_dir):
os.mkdir(args.output_segment_dir)
# This assumes that we have language annotations that are stored in a nested dictionary
# The keys at the first level are the video names or ids
# The values are dictionaries that contain the start and end times as well as the text
# This format can be easily modified to suit different datasets.
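# A hypothetical example of the expected layout (values are illustrative only):
# {
#   "video_001": {"start": [0.0, 12.4], "end": [10.2, 20.0],
#                 "text": ["person opens a drawer", "person picks up a cup"]}
# }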
data = json.load(open(args.ann_path))
video2anns = {}
for idx, vid in enumerate(data):
if idx % 100 == 0:
print(idx)
curr = data[vid]
start = curr['start']
end = curr['end']
text = curr['text']
video2anns[vid] = {'start': start, 'end': end, 'text': text}
#json.dump(video2anns, open('/path/to/vid_to_anns.json', 'w'))
all_vids = list(video2anns.keys())
print('all_vids: ', len(all_vids))
print('')
processes = []
video_num = len(all_vids)
per_process_video_num = video_num // args.thread_num
for i in range(args.thread_num):
if i == args.thread_num - 1:
sub_files = all_vids[i * per_process_video_num :]
else:
sub_files = all_vids[i * per_process_video_num : (i + 1) * per_process_video_num]
p = mp.Process(target=sub_processor, args=(i, sub_files, video2anns))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
main()