from enum import Enum
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedModel
from typing import List, Optional
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_EOS_TOKEN = '</s>'
DEFAULT_BOS_TOKEN = '<s>'
DEFAULT_UNK_TOKEN = '<unk>'
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_BBOX_TOKEN = "<bbox>"
# Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99 # noqa: E501
def prepare_inputs_labels_for_multimodal(
llm: PreTrainedModel,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
**kwargs):
if pixel_values is None:
kwargs.update({
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'past_key_values': past_key_values,
'inputs_embeds': None,
'labels': labels
})
return kwargs
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- TODO: double check
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
]
new_inputs_embeds = []
new_labels = []
new_input_ids = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_pixel_values = pixel_values[cur_image_idx]
cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
new_inputs_embeds.append(cur_inputs_embeds)
new_labels.append(labels[batch_idx])
new_input_ids.append(cur_input_ids)
cur_image_idx += 1
continue
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_inputs_embeds = llm.get_input_embeddings()(torch.cat(cur_input_ids_noim))
cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0)
cur_new_inputs_embeds = []
cur_new_labels = []
cur_new_input_ids = []
for i in range(num_images + 1):
cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
cur_new_input_ids.append(cur_input_ids_noim[i])
if i < num_images:
cur_pixel_values = pixel_values[cur_image_idx]
cur_image_idx += 1
cur_new_inputs_embeds.append(cur_pixel_values)
cur_new_labels.append(torch.full((cur_pixel_values.shape[0], ), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_ids.append(torch.full((cur_pixel_values.shape[0], ), IMAGE_TOKEN_INDEX, device=cur_input_ids.device, dtype=cur_input_ids.dtype))
cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
cur_new_labels = torch.cat(cur_new_labels)
cur_new_input_ids = torch.cat(cur_new_input_ids)
new_inputs_embeds.append(cur_new_inputs_embeds)
new_labels.append(cur_new_labels)
new_input_ids.append(cur_new_input_ids)
# Combine them
max_len = max(x.shape[0] for x in new_inputs_embeds)
batch_size = len(new_inputs_embeds)
new_inputs_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
new_input_ids_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_inputs_embeds, new_labels, new_input_ids)):
cur_len = cur_new_embed.shape[0]
new_inputs_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
new_input_ids_padded[i, :cur_len] = cur_new_input_ids
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
new_input_ids = new_input_ids_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
kwargs.update({
'input_ids': None,
'position_ids': position_ids,
'attention_mask': attention_mask,
'past_key_values': past_key_values,
'inputs_embeds': new_inputs_embeds,
'labels': new_labels,
'new_input_ids': new_input_ids
})
return kwargs
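# --- Illustrative usage sketch (added for clarity; not part of the original commit) ---
# `prepare_inputs_labels_for_multimodal` splices visual features into the token-embedding
# sequence wherever IMAGE_TOKEN_INDEX (-200) appears. The toy module below is an assumption
# for demonstration only: it provides `get_input_embeddings`, the single LLM hook the
# function relies on, so the hidden size (8) must match the pixel-feature dimension.
def _demo_prepare_inputs_labels_for_multimodal():
    class _ToyLLM(torch.nn.Module):
        def __init__(self, vocab_size=32000, hidden=8):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden)

        def get_input_embeddings(self):
            return self.embed

    llm = _ToyLLM()
    # one <image> placeholder between ordinary text tokens
    input_ids = torch.tensor([[1, 5, IMAGE_TOKEN_INDEX, 7, 2]])
    # one image whose projected features span 4 "visual tokens" of the LLM hidden size
    pixel_values = torch.randn(1, 4, 8)
    out = prepare_inputs_labels_for_multimodal(
        llm=llm, input_ids=input_ids, pixel_values=pixel_values)
    # 4 text embeddings + 4 visual embeddings -> sequence length 8
    assert out['inputs_embeds'].shape == (1, 8, 8)
    assert out['input_ids'] is None  # the LLM must now consume inputs_embeds instead
# _demo_prepare_inputs_labels_for_multimodal()  # uncomment for a quick sanity check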
class Summary(Enum):
NONE = 0
AVERAGE = 1
SUM = 2
COUNT = 3
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
self.name = name
self.fmt = fmt
self.summary_type = summary_type
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def all_reduce(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(self.sum, np.ndarray):
total = torch.tensor(
self.sum.tolist()
+ [
self.count,
],
dtype=torch.float32,
device=device,
)
else:
total = torch.tensor(
[self.sum, self.count], dtype=torch.float32, device=device
)
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
if total.shape[0] > 2:
self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
else:
self.sum, self.count = total.tolist()
self.avg = self.sum / (self.count + 1e-5)
def __str__(self):
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = ""
if self.summary_type is Summary.NONE:
fmtstr = ""
elif self.summary_type is Summary.AVERAGE:
fmtstr = "{name} {avg:.3f}"
elif self.summary_type is Summary.SUM:
fmtstr = "{name} {sum:.3f}"
elif self.summary_type is Summary.COUNT:
fmtstr = "{name} {count:.3f}"
else:
raise ValueError("invalid summary type %r" % self.summary_type)
return fmtstr.format(**self.__dict__)
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
assert output.dim() in [1, 2, 3]
assert output.shape == target.shape
output = output.view(-1)
target = target.view(-1)
output[target == ignore_index] = ignore_index
intersection = output[output == target]
area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
area_output = torch.histc(output, bins=K, min=0, max=K - 1)
area_target = torch.histc(target, bins=K, min=0, max=K - 1)
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
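# --- Illustrative mIoU sketch (added for clarity; not part of the original commit) ---
# Accumulating per-class intersection and union over a dataset and dividing at the end
# gives the usual mean IoU. Float tensors are used here because torch.histc does not
# accept integer input; in training this would normally run on CUDA tensors, hence the
# function name. The toy prediction/target pair and K=2 are assumptions for the demo.
def _demo_miou(K=2):
    pred = torch.tensor([0., 0., 1., 1., 1.])
    target = torch.tensor([0., 1., 1., 1., 255.])  # 255 marks ignored positions
    inter, union, _ = intersectionAndUnionGPU(pred.clone(), target, K, ignore_index=255)
    iou = inter / (union + 1e-10)   # per-class IoU, here [0.5, 0.6667]
    return iou.mean().item()        # mIoU ~= 0.583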
class ProgressMeter(object):
def __init__(self, num_batches, meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
def display(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
print("\t".join(entries))
def display_summary(self):
entries = [" *"]
entries += [meter.summary() for meter in self.meters]
print(" ".join(entries))
def _get_batch_fmtstr(self, num_batches):
num_digits = len(str(num_batches // 1))
fmt = "{:" + str(num_digits) + "d}"
return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def dict_to_cuda(input_dict):
for k, v in input_dict.items():
if isinstance(input_dict[k], torch.Tensor):
input_dict[k] = v.cuda(non_blocking=True)
elif isinstance(v, list) and len(v) > 0:
input_dict[k] = [ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele for ele in v]
return input_dict
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer
from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory
from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss
from peft import LoraConfig
from projects.llava_sam2.models.internvl import InternVL_Slowfast
from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
path = './pretrained/InternVL2_5-1B/'
pretrained_pth = None
# Data
template = "qwen_chat"
prompt_template = PROMPT_TEMPLATE.qwen_chat
max_length = 8192
# Scheduler & Optimizer
batch_size = 2 # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1 # grad clip
warmup_ratio = 0.05
# Save
save_steps = 1000
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=path,
trust_remote_code=True,
padding_side='right')
extra_image_processor = dict(
type=DirectResize,
target_length=1024,
)
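# Batch-size arithmetic (added note; the GPU count is an assumption, not set by this file):
# per-device samples per optimizer step = batch_size * accumulative_counts = 2 * 4 = 8;
# the global batch additionally scales with the number of processes, e.g. 8 GPUs -> 64.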
#######################################################################
# PART 2 Model & Tokenizer & Image Processor #
#######################################################################
model = dict(
type=VideoLLaVASAMModel_zero3,
special_tokens=special_tokens,
frozen_sam2_decoder=False,
mllm=dict(
type=InternVL_Slowfast,
model_path=path,
freeze_llm=True,
freeze_visual_encoder=True,
llm_lora=dict(
type=LoraConfig,
r=128,
lora_alpha=256,
lora_dropout=0.05,
bias='none',
task_type='CAUSAL_LM'),
special_tokens=special_tokens,
),
tokenizer=tokenizer,
grounding_encoder=dict(
type=SAM2TrainRunner,
),
loss_mask=dict(
type=CrossEntropyLoss,
use_sigmoid=True,
reduction='mean',
loss_weight=2.0),
loss_dice=dict(
type=DiceLoss,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=0.5),
pretrained_pth=pretrained_pth,
loss_sample_points=True,
# loss_sample_points=False,
bs=batch_size,
)
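# LoRA note (added for clarity): with r=128 and lora_alpha=256 the adapter scaling is
# lora_alpha / r = 2.0. freeze_llm and freeze_visual_encoder keep the InternVL backbone
# frozen so only the LoRA adapters are updated there, while frozen_sam2_decoder=False
# leaves the SAM-2 mask decoder trainable.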
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
DATA_ROOT = './data/'
VIDEO_DATA_ROOT = DATA_ROOT + 'video_datas/'
############### video res
data_root_revos = VIDEO_DATA_ROOT + 'revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'
data_root_mevis = VIDEO_DATA_ROOT + 'mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'
data_root_refytvos = VIDEO_DATA_ROOT + 'rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'
video_revos_dataset = dict(
type=VideoReVOSDataset,
image_folder=video_revos_image_folder,
expression_file=video_revos_expression_file,
mask_file=video_revos_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=10,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
video_mevis_dataset = dict(
type=VideoMeVISDataset,
image_folder=video_mevis_image_folder,
expression_file=video_mevis_expression_file,
mask_file=video_mevis_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
video_refytvos_dataset = dict(
type=VideoRefYoutubeVOSDataset,
image_folder=video_refytvos_image_folder,
expression_file=video_refytvos_expression_file,
mask_file=video_refytvos_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
################### Video chat
data_root_video_chatunivi = VIDEO_DATA_ROOT + 'chat_univi/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'
video_qa_dataset = dict(
type=VideoChatUniViDataset,
image_folder=video_chatunivi_image_folder,
json_file=video_chatunivi_json_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
################## image chat
LLAVA_ROOT = DATA_ROOT + 'llava_data/'
llava_vqa_dataset = dict(
type=LLaVADataset,
tokenizer=tokenizer,
data_path=LLAVA_ROOT + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
prompt_template=prompt_template,
special_tokens=special_tokens,
image_folder=LLAVA_ROOT + 'llava_images/',
)
################## image res
RES_ROOT = DATA_ROOT + 'ref_seg/'
refcoco_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcoco',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(unc).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
refcoco_plus_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcoco+',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(unc).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
refcocog_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcocog',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(umd).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
# image gcg data
glamm_data_root = DATA_ROOT + 'glamm_data/'
refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
glamm_refcocog_dataset = dict(
type=RefCOCOgGCGDataset,
image_folder=refcocog_image_path,
data_path=refcocog_ann_file,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
glamm_grandf_dataset = dict(
type=GranDfGCGDataset,
data_path=grandf_ann_file,
image_folder=grandf_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=10,
)
glamm_psg_dataset = dict(
type=OpenPsgGCGDataset,
data_path=psg_ann_file,
image_folder=psg_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
glamm_flickr_dataset = dict(
type=FlickrGCGDataset,
data_path=flickr_ann_file,
image_folder=flickr_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
# sam2 data
data_sam2_folder = VIDEO_DATA_ROOT + 'sam_v_full/'
data_sam2_expression_file = VIDEO_DATA_ROOT + 'sam_v_final_v3.json'
video_sam2_dataset = dict(
type=VideoSAM2Dataset,
sam2_folder=data_sam2_folder,
expression_file=data_sam2_expression_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
select_number=5,
)
# osprey
OSPREY_ROOT = DATA_ROOT + "osprey-724k/"
data_osprey_file = OSPREY_ROOT + 'Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
OSPREY_ROOT + 'coco/train2014/',
OSPREY_ROOT + 'coco/val2014/',
OSPREY_ROOT + 'coco/train2017/',
OSPREY_ROOT + 'coco/val2017/',
]
image_osprey_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_detail_description_file = OSPREY_ROOT + 'Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
type=OspreyDescriptionDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_detail_description_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_short_file = OSPREY_ROOT + 'Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
type=OspreyShortDescriptionDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_short_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_part_file = OSPREY_ROOT + 'Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_part_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_positive_neg_file = OSPREY_ROOT + 'Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_positive_neg_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
train_dataset = dict(
type=ConcatDataset, datasets=[
# sem seg
# semantic_seg_ade20k_dataset,
# ref seg
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
# image qa
llava_vqa_dataset,
# video res
video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
# video chat
video_qa_dataset,
# sam2 pseudo data
video_sam2_dataset,
# gcg data
glamm_psg_dataset,
glamm_grandf_dataset,
glamm_flickr_dataset,
glamm_refcocog_dataset,
# visual prompt
image_osprey_dataset, image_osprey_description_dataset,
image_osprey_part_dataset, image_osprey_short_dataset,
image_osprey_positive_neg_dataset,
]
)
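# Sampling note (added): the three ReferSegm datasets are listed four times above, which
# oversamples referring-expression segmentation 4x in the concatenated mixture, on top of
# the per-dataset `repeats` values used by the video datasets.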
train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=train_dataset,
sampler=dict(
type=LengthGroupedSampler,
length_property='modality_length',
per_device_batch_size=batch_size * accumulative_counts),
collate_fn=dict(type=video_lisa_collate_fn)
)
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='bfloat16'
)
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=warmup_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
begin=warmup_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]
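# Scheduling note (added): warmup_ratio * max_epochs = 0.05 epochs, so LinearLR ramps the
# learning rate from lr * 1e-5 up to lr over the first 5% of the single training epoch
# (converted to iterations), after which CosineAnnealingLR decays it to eta_min=0.0.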
# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
# dict(type=DatasetInfoHook, tokenizer=tokenizer),
]
# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
save_optimizer=False,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed environment.
sampler_seed=dict(type=DistSamplerSeedHook),
)
# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)
# set visualizer
visualizer = None
# set log level
log_level = 'INFO'
# load from which checkpoint
load_from = None
# whether to resume training from the loaded checkpoint
resume = False
# Defaults to a random seed with `deterministic` disabled
randomness = dict(seed=None, deterministic=False)
# set log processor
log_processor = dict(by_epoch=False)
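# Launch sketch (illustrative, not part of this config): an mmengine/xtuner-style config
# like the one above is normally passed to the xtuner CLI, e.g.
#   NPROC_PER_NODE=8 xtuner train <path/to/this_config>.py --deepspeed deepspeed_zero3
# The config path and GPU count are assumptions; DeepSpeed ZeRO-3 matches the
# VideoLLaVASAMModel_zero3 model type used here.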
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer
from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory
from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss
from peft import LoraConfig
from projects.llava_sam2.models.internvl import InternVL_Slowfast
from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
path = './pretrained/InternVL2_5-4B'
pretrained_pth = None
# Data
template = "phi3_chat"
prompt_template = PROMPT_TEMPLATE.phi3_chat
max_length = 8192
# Scheduler & Optimizer
batch_size = 2 # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1 # grad clip
warmup_ratio = 0.05
# Save
save_steps = 1000
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=path,
trust_remote_code=True,
padding_side='right')
extra_image_processor = dict(
type=DirectResize,
target_length=1024,
)
#######################################################################
# PART 2 Model & Tokenizer & Image Processor #
#######################################################################
model = dict(
type=VideoLLaVASAMModel_zero3,
special_tokens=special_tokens,
frozen_sam2_decoder=False,
mllm=dict(
type=InternVL_Slowfast,
model_path=path,
freeze_llm=True,
freeze_visual_encoder=True,
llm_lora=dict(
type=LoraConfig,
r=128,
lora_alpha=256,
lora_dropout=0.05,
bias='none',
task_type='CAUSAL_LM'),
special_tokens=special_tokens,
),
tokenizer=tokenizer,
grounding_encoder=dict(
type=SAM2TrainRunner,
),
loss_mask=dict(
type=CrossEntropyLoss,
use_sigmoid=True,
reduction='mean',
loss_weight=2.0),
loss_dice=dict(
type=DiceLoss,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=0.5),
pretrained_pth=pretrained_pth,
loss_sample_points=True,
# loss_sample_points=False,
bs=batch_size,
)
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
DATA_ROOT = './data/'
VIDEO_DATA_ROOT = DATA_ROOT + 'video_datas/'
############### video res
data_root_revos = VIDEO_DATA_ROOT + 'revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'
data_root_mevis = VIDEO_DATA_ROOT + 'mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'
data_root_refytvos = VIDEO_DATA_ROOT + 'rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'
video_revos_dataset = dict(
type=VideoReVOSDataset,
image_folder=video_revos_image_folder,
expression_file=video_revos_expression_file,
mask_file=video_revos_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=10,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
video_mevis_dataset = dict(
type=VideoMeVISDataset,
image_folder=video_mevis_image_folder,
expression_file=video_mevis_expression_file,
mask_file=video_mevis_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
video_refytvos_dataset = dict(
type=VideoRefYoutubeVOSDataset,
image_folder=video_refytvos_image_folder,
expression_file=video_refytvos_expression_file,
mask_file=video_refytvos_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
################### Video chat
data_root_video_chatunivi = VIDEO_DATA_ROOT + 'chat_univi/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'
video_qa_dataset = dict(
type=VideoChatUniViDataset,
image_folder=video_chatunivi_image_folder,
json_file=video_chatunivi_json_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
################## image chat
LLAVA_ROOT = DATA_ROOT + 'llava_data/'
llava_vqa_dataset = dict(
type=LLaVADataset,
tokenizer=tokenizer,
data_path=LLAVA_ROOT + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
prompt_template=prompt_template,
special_tokens=special_tokens,
image_folder=LLAVA_ROOT + 'llava_images/',
)
################## image res
RES_ROOT = DATA_ROOT + 'ref_seg/'
refcoco_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcoco',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(unc).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
refcoco_plus_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcoco+',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(unc).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
refcocog_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcocog',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(umd).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
# image gcg data
glamm_data_root = DATA_ROOT + 'glamm_data/'
refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
glamm_refcocog_dataset = dict(
type=RefCOCOgGCGDataset,
image_folder=refcocog_image_path,
data_path=refcocog_ann_file,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
glamm_grandf_dataset = dict(
type=GranDfGCGDataset,
data_path=grandf_ann_file,
image_folder=grandf_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=10,
)
glamm_psg_dataset = dict(
type=OpenPsgGCGDataset,
data_path=psg_ann_file,
image_folder=psg_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
glamm_flickr_dataset = dict(
type=FlickrGCGDataset,
data_path=flickr_ann_file,
image_folder=flickr_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
# sam2 data
data_sam2_folder = VIDEO_DATA_ROOT + 'sam_v_full/'
data_sam2_expression_file = VIDEO_DATA_ROOT + 'sam_v_final_v3.json'
video_sam2_dataset = dict(
type=VideoSAM2Dataset,
sam2_folder=data_sam2_folder,
expression_file=data_sam2_expression_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
select_number=5,
)
# osprey
OSPREY_ROOT = DATA_ROOT + "osprey-724k/"
data_osprey_file = OSPREY_ROOT + 'Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
OSPREY_ROOT + 'coco/train2014/',
OSPREY_ROOT + 'coco/val2014/',
OSPREY_ROOT + 'coco/train2017/',
OSPREY_ROOT + 'coco/val2017/',
]
image_osprey_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_detail_description_file = OSPREY_ROOT + 'Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
type=OspreyDescriptionDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_detail_description_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_short_file = OSPREY_ROOT + 'Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
type=OspreyShortDescriptionDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_short_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_part_file = OSPREY_ROOT + 'Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_part_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_positive_neg_file = OSPREY_ROOT + 'Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_positive_neg_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
train_dataset = dict(
type=ConcatDataset, datasets=[
# sem seg
# semantic_seg_ade20k_dataset,
# ref seg
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
# image qa
llava_vqa_dataset,
# video res
video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
# video chat
video_qa_dataset,
# sam2 pseudo data
video_sam2_dataset,
# gcg data
glamm_psg_dataset,
glamm_grandf_dataset,
glamm_flickr_dataset,
glamm_refcocog_dataset,
# visual prompt
image_osprey_dataset, image_osprey_description_dataset,
image_osprey_part_dataset, image_osprey_short_dataset,
image_osprey_positive_neg_dataset,
]
)
train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=train_dataset,
sampler=dict(
type=LengthGroupedSampler,
length_property='modality_length',
per_device_batch_size=batch_size * accumulative_counts),
collate_fn=dict(type=video_lisa_collate_fn)
)
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='bfloat16'
)
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=warmup_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
begin=warmup_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]
# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
# dict(type=DatasetInfoHook, tokenizer=tokenizer),
]
# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
save_optimizer=False,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed environment.
sampler_seed=dict(type=DistSamplerSeedHook),
)
# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)
# set visualizer
visualizer = None
# set log level
log_level = 'INFO'
# load from which checkpoint
load_from = None
# whether to resume training from the loaded checkpoint
resume = False
# Defaults to a random seed with `deterministic` disabled
randomness = dict(seed=None, deterministic=False)
# set log processor
log_processor = dict(by_epoch=False)
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer
from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory
from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss
from peft import LoraConfig
from projects.llava_sam2.models.internvl import InternVL_Slowfast
from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
path = './pretrained/InternVL2_5-8B'
pretrained_pth = None
# Data
template = "internlm2_chat"
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 8192
# Scheduler & Optimizer
batch_size = 2 # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1 # grad clip
warmup_ratio = 0.05
# Save
save_steps = 2000
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=path,
trust_remote_code=True,
padding_side='right')
extra_image_processor = dict(
type=DirectResize,
target_length=1024,
)
#######################################################################
# PART 2 Model & Tokenizer & Image Processor #
#######################################################################
model = dict(
type=VideoLLaVASAMModel_zero3,
special_tokens=special_tokens,
frozen_sam2_decoder=False,
mllm=dict(
type=InternVL_Slowfast,
model_path=path,
freeze_llm=True,
freeze_visual_encoder=True,
llm_lora=dict(
type=LoraConfig,
r=128,
lora_alpha=256,
lora_dropout=0.05,
bias='none',
task_type='CAUSAL_LM'),
special_tokens=special_tokens,
),
tokenizer=tokenizer,
grounding_encoder=dict(
type=SAM2TrainRunner,
),
loss_mask=dict(
type=CrossEntropyLoss,
use_sigmoid=True,
reduction='mean',
loss_weight=2.0),
loss_dice=dict(
type=DiceLoss,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=0.5),
pretrained_pth=pretrained_pth,
loss_sample_points=True,
# loss_sample_points=False,
bs=batch_size,
)
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
DATA_ROOT = './data/'
VIDEO_DATA_ROOT = DATA_ROOT + 'video_datas/'
############### video res
data_root_revos = VIDEO_DATA_ROOT + 'revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'
data_root_mevis = VIDEO_DATA_ROOT + 'mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'
data_root_refytvos = VIDEO_DATA_ROOT + 'rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'
video_revos_dataset = dict(
type=VideoReVOSDataset,
image_folder=video_revos_image_folder,
expression_file=video_revos_expression_file,
mask_file=video_revos_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=10,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
video_mevis_dataset = dict(
type=VideoMeVISDataset,
image_folder=video_mevis_image_folder,
expression_file=video_mevis_expression_file,
mask_file=video_mevis_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
video_refytvos_dataset = dict(
type=VideoRefYoutubeVOSDataset,
image_folder=video_refytvos_image_folder,
expression_file=video_refytvos_expression_file,
mask_file=video_refytvos_mask_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
################### Video chat
data_root_video_chatunivi = VIDEO_DATA_ROOT + 'chat_univi/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'
video_qa_dataset = dict(
type=VideoChatUniViDataset,
image_folder=video_chatunivi_image_folder,
json_file=video_chatunivi_json_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
)
################## image chat
LLAVA_ROOT = DATA_ROOT + 'llava_data/'
llava_vqa_dataset = dict(
type=LLaVADataset,
tokenizer=tokenizer,
data_path=LLAVA_ROOT + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
prompt_template=prompt_template,
special_tokens=special_tokens,
image_folder=LLAVA_ROOT + 'llava_images/',
)
################## image res
RES_ROOT = DATA_ROOT + 'ref_seg/'
refcoco_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcoco',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(unc).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
refcoco_plus_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcoco+',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(unc).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
refcocog_segm_dataset=dict(
type=ReferSegmDataset,
tokenizer=tokenizer,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
data_root=RES_ROOT + 'refcocog',
data_prefix=dict(img_path='coco2014/train2014/'),
ann_file='instances.json',
split_file='refs(umd).p',
prompt_template=prompt_template,
num_classes_per_sample=5,
max_length=max_length,
)
# image gcg data
glamm_data_root = DATA_ROOT + 'glamm_data/'
refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
glamm_refcocog_dataset = dict(
type=RefCOCOgGCGDataset,
image_folder=refcocog_image_path,
data_path=refcocog_ann_file,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
glamm_grandf_dataset = dict(
type=GranDfGCGDataset,
data_path=grandf_ann_file,
image_folder=grandf_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=10,
)
glamm_psg_dataset = dict(
type=OpenPsgGCGDataset,
data_path=psg_ann_file,
image_folder=psg_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
glamm_flickr_dataset = dict(
type=FlickrGCGDataset,
data_path=flickr_ann_file,
image_folder=flickr_image_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
extra_image_processor=extra_image_processor,
lazy=True,
repeats=1,
)
# sam2 data
data_sam2_folder = VIDEO_DATA_ROOT + 'sam_v_full/'
data_sam2_expression_file = VIDEO_DATA_ROOT + 'sam_v_final_v3.json'
video_sam2_dataset = dict(
type=VideoSAM2Dataset,
sam2_folder=data_sam2_folder,
expression_file=data_sam2_expression_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=4,
special_tokens=special_tokens,
extra_image_processor=extra_image_processor,
sampled_frames=5,
select_number=5,
)
# osprey
OSPREY_ROOT = DATA_ROOT + "osprey-724k/"
data_osprey_file = OSPREY_ROOT + 'Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
OSPREY_ROOT + 'coco/train2014/',
OSPREY_ROOT + 'coco/val2014/',
OSPREY_ROOT + 'coco/train2017/',
OSPREY_ROOT + 'coco/val2017/',
]
image_osprey_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_detail_description_file = OSPREY_ROOT + 'Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
type=OspreyDescriptionDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_detail_description_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_short_file = OSPREY_ROOT + 'Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
type=OspreyShortDescriptionDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_short_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_part_file = OSPREY_ROOT + 'Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_part_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
data_osprey_positive_neg_file = OSPREY_ROOT + 'Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
type=OspreyDataset,
image_folder=data_osprey_image_folders,
data_path=data_osprey_positive_neg_file,
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
lazy=True,
repeats=1,
special_tokens=special_tokens,
)
train_dataset = dict(
type=ConcatDataset, datasets=[
# sem seg
# semantic_seg_ade20k_dataset,
# ref seg
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
# image qa
llava_vqa_dataset,
# video res
video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
# video chat
video_qa_dataset,
# sam2 pseudo data
video_sam2_dataset,
# gcg data
glamm_psg_dataset,
glamm_grandf_dataset,
glamm_flickr_dataset,
glamm_refcocog_dataset,
# visual prompt
image_osprey_dataset, image_osprey_description_dataset,
image_osprey_part_dataset, image_osprey_short_dataset,
image_osprey_positive_neg_dataset,
]
)
train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=train_dataset,
sampler=dict(
type=LengthGroupedSampler,
length_property='modality_length',
per_device_batch_size=batch_size * accumulative_counts),
collate_fn=dict(type=video_lisa_collate_fn)
)
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='bfloat16'
)
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=warmup_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
begin=warmup_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]
# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
# dict(type=DatasetInfoHook, tokenizer=tokenizer),
]
# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
save_optimizer=False,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed environment.
sampler_seed=dict(type=DistSamplerSeedHook),
)
# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)
# set visualizer
visualizer = None
# set log level
log_level = 'INFO'
# load from which checkpoint
load_from = None
# whether to resume training from the loaded checkpoint
resume = False
# Defaults to a random seed with `deterministic` disabled
randomness = dict(seed=None, deterministic=False)
# set log processor
log_processor = dict(by_epoch=False)
import logging
import os
from typing import Literal
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import build_origin_dataset
import copy
from .encode_fn import video_lisa_encode_fn
import json
import cv2
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from decord import VideoReader, cpu
def _get_rawvideo_dec(video_path, select_frames=5):
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
elif os.path.exists(video_path.replace('mkv', 'mp4')):
vreader = VideoReader(video_path.replace('mkv', 'mp4'), ctx=cpu(0))
else:
raise FileNotFoundError(video_path)
fps = vreader.get_avg_fps()
f_start = 0
f_end = len(vreader) - 1
num_frames = f_end - f_start + 1
assert num_frames > 0, f'num_frames: {num_frames}, f_start: {f_start}, f_end: {f_end}, fps: {fps}, video_path: {video_path}'
# T x 3 x H x W
if num_frames <= select_frames:
sample_pos = range(f_start, f_end + 1)
else:
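# Stratified sampling: split the frame range into `select_frames` equal bins and
# draw one random frame index from each bin, so the clip is covered evenly.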
split_point = np.linspace(0, num_frames, num=select_frames+1, dtype=int)
sample_pos = [np.random.randint(split_point[i], split_point[i+1]) for i in range(select_frames)]
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
return patch_images
class VideoChatUniViDataset(Dataset):
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
FAST_IMG_START_TOKEN = '<fast_img>'
FAST_IMG_END_TOKEN = '</fast_img>'
def __init__(self,
image_folder,
json_file,
extra_image_processor=None,
tokenizer=None,
sampled_frames=10,
offline_processed_text_folder=None,
template_map_fn=None,
max_length=2048,
lazy=True,
repeats=1,
special_tokens=None,
use_fast=False,
n_fast_images=50,
fast_pool_size=4,
arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
preprocessor=None,
):
assert lazy is True
self.tokenizer = BUILDER.build(tokenizer)
self.sampled_frames = sampled_frames
assert offline_processed_text_folder or (json_file and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.template_map_fn = template_map_fn
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn['type']
del self.template_map_fn['type']
self.template_map_fn = _type(**self.template_map_fn)
if offline_processed_text_folder and json_file:
print_log(
'Both `offline_processed_text_folder` and '
'`data_path` are set, and we load dataset from'
'`offline_processed_text_folder` '
f'({offline_processed_text_folder})',
logger='current',
level=logging.WARNING)
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
json_datas = self.json_file_preprocess(json_file)
self.json_datas = json_datas
json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
if self.lazy:
self.text_data = build_origin_dataset(json_data, 'train')
else:
raise NotImplementedError
self.image_folder = image_folder
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.arch_type = arch_type
if self.arch_type == 'qwen':
self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
self.IMG_START_TOKEN = '<|vision_start|>'
self.IMG_END_TOKEN = '<|vision_end|>'
elif self.arch_type == 'llava':
self.IMG_CONTEXT_TOKEN = '<image>'
self.IMG_START_TOKEN = ''
self.IMG_END_TOKEN = ''
self.repeats = repeats
self._system = ''
self.downsample_ratio = 0.5
if self.arch_type == 'llava':
self.downsample_ratio = 1
self.image_size = 448
if self.arch_type == 'llava':
self.image_size = 336
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
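# e.g. intern_vl: (448 // 14) ** 2 * 0.5 ** 2 = 1024 * 0.25 = 256 visual tokens per
# frame; llava: (336 // 14) ** 2 * 1 ** 2 = 576. The qwen branch below overrides this.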
if self.arch_type == 'qwen':
self.patch_token = 1
if preprocessor is None:
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
self.preprocessor = None
else:
self.transformer = None
self.preprocessor = BUILDER.build(preprocessor)
self.arch_type = arch_type
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.use_fast = use_fast
self.n_fast_images = n_fast_images
self.fast_pool_size = fast_pool_size
# for visualization debug
self.save_folder = './work_dirs/video_debug/'
self.cur_number = 0
print("Video Chat dataset, include {} items.".format(len(self.text_data)))
def __len__(self):
return len(self.text_data) * self.repeats
@property
def modality_length(self):
length_list = []
for data_dict in self.text_data:
cur_len = 10000
length_list.append(cur_len)
return length_list
def real_len(self):
return len(self.text_data)
def json_file_preprocess(self, json_file):
# prepare expression annotation files
with open(json_file, 'r') as f:
json_datas = json.load(f)
return json_datas
def dataset_map_fn(self, data_dict, select_k=5):
assert 'video' in data_dict
# video
video_file = data_dict['video']
video_file = os.path.join(self.image_folder, video_file)
images = _get_rawvideo_dec(video_file, select_frames=select_k)
if self.use_fast:
fast_images = _get_rawvideo_dec(video_file, select_frames=self.n_fast_images)
else:
fast_images = None
conversation = data_dict['conversations']
# prepare text
if self.use_fast:
text_dict = self.prepare_text(
select_k, conversation, num_image_tokens=self.patch_token,
n_fast_images=len(fast_images),
)
else:
text_dict = self.prepare_text(
select_k, conversation, num_image_tokens=self.patch_token,
)
ret = {'images': images, 'conversation': text_dict['conversation'], 'fast_images': fast_images}
return ret
def prepare_text(self, n_frames, conversation, num_image_tokens=256, n_fast_images=0):
if self.use_fast:
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
f'{self.FAST_IMG_END_TOKEN}' + '\n'
else:
fast_frame_token_str = ''
frame_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
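# A single frame is rendered as IMG_START + `num_image_tokens` copies of IMG_CONTEXT
# + IMG_END; it is repeated once per sampled frame further below.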
questions = []
answers = []
for conv in conversation:
if conv['from'] == 'human':
questions.append(conv['value'].replace('<image>', ''))
else:
answers.append(conv['value'])
assert len(questions) == len(answers)
qa_list = []
for i, (question, answer) in enumerate(zip(questions, answers)):
if i == 0:
frame_tokens = frame_token_str + '\n'
# frame_tokens = '=' + ' '
frame_tokens = frame_tokens * n_frames
frame_tokens = frame_tokens.strip()
frame_tokens = fast_frame_token_str + frame_tokens
qa_list.append(
{'from': 'human', 'value': frame_tokens + question}
)
else:
qa_list.append(
{'from': 'human', 'value': question}
)
qa_list.append(
{'from': 'gpt', 'value': answer}
)
input = ''
conversation = []
for msg in qa_list:
if msg['from'] == 'human':
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
# add system information
conversation[0].update({'system': self._system})
return {'conversation': conversation}
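# Illustrative result (placeholder values), assuming two QA turns:
# {'conversation': [
#     {'system': '', 'input': '<fast_img>...</fast_img><img>...</img>\n...<question 1>', 'output': '<answer 1>'},
#     {'input': '<question 2>', 'output': '<answer 2>'}]}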
def __getitem__(self, index):
index = index % self.real_len()
selected_data_dict = copy.deepcopy(self.text_data[index])
data_dict = self.dataset_map_fn(selected_data_dict, select_k=self.sampled_frames)
assert 'images' in data_dict.keys()
if self.use_fast:
assert 'fast_images' in data_dict.keys()
pixel_values = []
num_video_tokens = None
num_frame_tokens = None
if data_dict.get('images', None) is not None:
frames_files = data_dict['images']
for frame_image in frames_files:
frame_image = frame_image.convert('RGB')
ori_width, ori_height = frame_image.size
if self.preprocessor is None:
    frame_image = self.transformer(frame_image)
pixel_values.append(frame_image)
if self.preprocessor is not None:
if self.arch_type == 'qwen':
_data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
_data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
num_frames = _data_dict['image_grid_thw'].shape[0]
num_video_tokens = num_frame_tokens * num_frames
elif self.arch_type == 'llava':
_data_dict = self.preprocessor(pixel_values, do_resize=True,
size=(self.image_size, self.image_size))
_data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
else:
raise NotImplementedError
data_dict.update(_data_dict)
else:
pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['pixel_values'] = pixel_values
else:
data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
data_dict['masks'] = None
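# Qwen path: `patch_token` is 1, so the prompt built in prepare_text holds one
# IMG_CONTEXT placeholder per frame; expand each of them to `num_frame_tokens`
# placeholders so the total count equals `num_video_tokens`.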
if num_video_tokens is not None:
assert self.patch_token == 1
input_str = data_dict['conversation'][0]['input']
input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
data_dict['conversation'][0]['input'] = input_str
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
data_dict.update(result)
# for fast branch
if self.use_fast:
fast_pixel_values = []
frames_files = data_dict['fast_images']
for frame_image in frames_files:
frame_image = frame_image.convert('RGB')
ori_width, ori_height = frame_image.size
frame_image = self.transformer(frame_image)
fast_pixel_values.append(frame_image)
fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['fast_pixel_values'] = fast_pixel_values
# # for debug
# self.visualization_debug(data_dict)
# if self.cur_number < 10:
# return self[random.randint(0, len(self))]
data_dict['type'] = 'video'
return data_dict
def visualization_debug(self, data_dict):
save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
if not os.path.exists(save_folder):
os.mkdir(save_folder)
self.cur_number += 1
# images
show_images = []
pixel_values = data_dict['pixel_values']
save_folder_image = os.path.join(save_folder, 'image')
if not os.path.exists(save_folder_image):
os.mkdir(save_folder_image)
for i_image, image_pixel_value in enumerate(pixel_values):
# print(image_pixel_value.shape)
for c in range(3):
    image_pixel_value[c] = image_pixel_value[c] * self.IMAGENET_STD[c] + self.IMAGENET_MEAN[c]
image_pixel_value = image_pixel_value * 255
image_pixel_value = image_pixel_value.permute(1, 2, 0)
image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
# print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
# print(image_pixel_value.shape)
show_images.append(image_pixel_value)
cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
# text
input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
with open(os.path.join(save_folder, 'text.json'), 'w') as f:
json.dump([input_text], f)
return
import json
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from PIL import Image
from torch.utils.data import Dataset
from pycocotools import mask
import numpy as np
import copy
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import torchvision.transforms as T
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from torchvision.transforms.functional import InterpolationMode
from .encode_fn import video_lisa_encode_fn
from .utils import dynamic_preprocess
from .gcg_process import glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn, glamm_refcocog_map_fn
class GCGDataset(Dataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super().__init__()
assert lazy
self.lazy = lazy
self.max_length = max_length
json_data = self.json_file_preprocess(data_path)
json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
self.text_data = build_origin_dataset(json_data, 'train')
self.image_folder = image_folder
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.template_map_fn = template_map_fn
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn['type']
del self.template_map_fn['type']
self.template_map_fn = _type(**self.template_map_fn)
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.repeats = repeats
self._system = ''
self.min_dynamic_patch = 1
self.max_dynamic_patch = 12
self.downsample_ratio = 0.5
self.image_size = 448
self.use_thumbnail = True
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.single_image_mode = single_image_mode
def json_file_preprocess(self, data_path):
with open(data_path, 'r') as f:
json_data = json.load(f)
return json_data
@property
def modality_length(self):
length_list = []
for data_dict in self.text_data:
if self.lazy:
cur_len = 100
else:
cur_len = len(data_dict['input_ids'])
if data_dict.get('image', None) is None:
cur_len = -cur_len
length_list.append(cur_len)
return length_list * self.repeats
def __len__(self):
return len(self.text_data) * self.repeats
def real_len(self):
return len(self.text_data)
def decode_mask(self, object_masks, ori_height, ori_width):
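# `object_masks` holds one list of polygon segments per object; each polygon is
# converted to RLE, decoded to an H x W binary mask, and the segments are summed
# into a single per-object mask. Returns an (n_obj, H, W) tensor, or None if empty.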
binary_masks = []
for object_mask in object_masks:
binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
for seg in object_mask:
rles = mask.frPyObjects([seg], ori_height, ori_width)
m = mask.decode(rles)
m = m.astype(np.uint8)
binary_mask += m.squeeze()
binary_masks.append(binary_mask)
if len(binary_masks) == 0:
return None
masks = np.stack(binary_masks, axis=0)
masks = torch.from_numpy(masks)
return masks
def dataset_map_fn(self, data_dict):
data_dict = glamm_refcocog_map_fn(data_dict)
return data_dict
def replace_image_str(self, data_dict, image_str):
data_dict['conversation'][0]['input'] = \
data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
return data_dict
def __getitem__(self, index):
index = index % self.real_len()
data_dict = copy.deepcopy(self.text_data[index])
# parse datasets
result = self.dataset_map_fn(data_dict)
data_dict.update(result)
# process image
image_file = data_dict['image']
image = Image.open(os.path.join(self.image_folder,
image_file)).convert('RGB')
ori_width, ori_height = image.size
if hasattr(self, 'extra_image_processor'):
g_image = np.array(image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
data_dict['g_pixel_values'] = g_pixel_values
if self.single_image_mode:
images = [image]
else:
images = dynamic_preprocess(image, self.min_dynamic_patch,
self.max_dynamic_patch,
self.image_size, self.use_thumbnail)
pixel_values = [self.transformer(image) for image in images]
pixel_values = torch.stack(pixel_values)
data_dict['pixel_values'] = pixel_values
num_image_tokens = pixel_values.shape[0] * self.patch_token
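# Total visual tokens = number of tiles from dynamic_preprocess (plus the optional
# thumbnail) * `patch_token` (256 per tile with the defaults above); the <image>
# placeholder is replaced by exactly this many IMG_CONTEXT tokens below.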
image_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
data_dict = self.replace_image_str(data_dict, image_token_str)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
with_image_token=True)
data_dict.update(result)
# process mask
data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
if data_dict['masks'] is None:
return self.__getitem__(0)
return data_dict
class RefCOCOgGCGDataset(GCGDataset):
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super().__init__(
image_folder=image_folder,
data_path=data_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=template_map_fn,
extra_image_processor=extra_image_processor,
lazy=lazy,
repeats=repeats,
single_image_mode=single_image_mode,
)
def json_file_preprocess(self, data_path):
with open(data_path, 'r') as f:
    json_data = json.load(f)
# convert {id: dict} to dict(..., id=xx)
for idx in range(len(json_data)):
id = list(json_data[idx].keys())[0]
json_data[idx] = json_data[idx][id]
json_data[idx].update({'id': id})
return json_data
class GranDfGCGDataset(GCGDataset):
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super().__init__(
image_folder=image_folder,
data_path=data_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=template_map_fn,
extra_image_processor=extra_image_processor,
lazy=lazy,
repeats=repeats,
single_image_mode=single_image_mode,
)
def dataset_map_fn(self, data_dict):
data_dict = glamm_granf_map_fn(data_dict)
return data_dict
def decode_mask(self, object_masks, ori_height, ori_width):
binary_masks = []
for object_mask in object_masks:
binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
for rle in object_mask:
m = mask.decode(rle).astype(np.uint8)
binary_mask += m.squeeze()
binary_masks.append(binary_mask)
if len(binary_masks) == 0:
return None
masks = np.stack(binary_masks, axis=0)
masks = torch.from_numpy(masks)
return masks
class OpenPsgGCGDataset(GranDfGCGDataset):
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super().__init__(
image_folder=image_folder,
data_path=data_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=template_map_fn,
extra_image_processor=extra_image_processor,
lazy=lazy,
repeats=repeats,
single_image_mode=single_image_mode,
)
def dataset_map_fn(self, data_dict):
data_dict = glamm_openpsg_map_fn(data_dict)
return data_dict
class FlickrGCGDataset(GCGDataset):
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super().__init__(
image_folder=image_folder,
data_path=data_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=template_map_fn,
extra_image_processor=extra_image_processor,
lazy=lazy,
repeats=repeats,
single_image_mode=single_image_mode,
)
def dataset_map_fn(self, data_dict):
data_dict = glamm_flickr_map_fn(data_dict)
return data_dict
def json_file_preprocess(self, data_path):
def filter_images(data_infos, min_size):
return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
# load the Flickr GCG annotations (COCO format) via the pycocotools API
from pycocotools.coco import COCO
self.coco = COCO(data_path)
self.image_ids = self.coco.getImgIds()
data_infos = []
total_ann_ids = []
removed_img_count = 0
for img_id in self.image_ids:
info = self.coco.loadImgs([img_id])[0]
if len(info['caption'].split(' ')) < 3:
removed_img_count += 1
continue
info['filename'] = info['file_name'].split('_')[-1]
info['height'] = int(info['height'])
info['width'] = int(info['width'])
data_infos.append(info)
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
print(f'Removed {removed_img_count} images.')
data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
# obtain_annotations
for data_info in data_infos:
ann_ids = self.coco.getAnnIds(imgIds=data_info['id'])
ann_info = self.coco.loadAnns(ann_ids)
data_info.update({'ann_info': ann_info})
return data_infos
def decode_mask(self, object_masks, ori_height, ori_width):
binary_masks = []
for object_mask in object_masks:
binary_mask = mask.decode(object_mask).astype(np.uint8)
binary_masks.append(binary_mask)
if len(binary_masks) == 0:
return None
masks = np.stack(binary_masks, axis=0)
masks = torch.from_numpy(masks)
return masks
import json
import os
import random
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from PIL import Image
from torch.utils.data import Dataset
from pycocotools import mask
import numpy as np
import copy
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import torchvision.transforms as T
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from torchvision.transforms.functional import InterpolationMode
from .encode_fn import video_lisa_encode_fn
from .utils import dynamic_preprocess
from .grand_process import glamm_grand_map_fn
class GranDDataset(Dataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
image_folder,
json_folder=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
image_list_save_path='./work_dirs/grand_image.json',
json_list_save_path='./work_dirs/grand_jsons.json',
):
super().__init__()
assert lazy
self.lazy = lazy
self.max_length = max_length
self.image_list_save_path = image_list_save_path
self.json_list_save_path = json_list_save_path
json_files, image_path_dict = self.json_file_preprocess(image_folder, json_folder)
self.json_data = json_files
self.image_path_dict = image_path_dict
self.image_folder = image_folder
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.template_map_fn = template_map_fn
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn['type']
del self.template_map_fn['type']
self.template_map_fn = _type(**self.template_map_fn)
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.repeats = repeats
self._system = ''
self.min_dynamic_patch = 1
self.max_dynamic_patch = 12
self.downsample_ratio = 0.5
self.image_size = 448
self.use_thumbnail = True
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.single_image_mode = single_image_mode
def json_file_preprocess(self, image_folder, json_folder):
# list jsons
print("Processing GRAND json files !!!")
if os.path.exists(self.json_list_save_path):
with open(self.json_list_save_path, 'r') as f:
json_files = json.load(f)
else:
json_files = os.listdir(json_folder)
_json_files = []
for _file in json_files:
if '.json' in _file:
_json_files.append(os.path.join(json_folder, _file))
json_files = _json_files
with open(self.json_list_save_path, 'w') as f:
json.dump(json_files, f)
print(f"Finished, {len(json_files)} json files !")
# list images
print("Processing GRAND image files !!!")
if os.path.exists(self.image_list_save_path):
with open(self.image_list_save_path, 'r') as f:
image_path_dict = json.load(f)
else:
sub_folders = os.listdir(image_folder)
_sub_folders = []
for folder_name in sub_folders:
if 'sa_00' in folder_name:
_sub_folders.append(folder_name)
sub_folders = _sub_folders
sub_folders = [os.path.join(image_folder, folder_name) for folder_name in sub_folders]
image_path_dict = {}
for sub_folder in sub_folders:
files = os.listdir(sub_folder)
for _file in files:
if '.jpg' in _file:
image_path_dict[_file] = os.path.join(sub_folder, _file)
with open(self.image_list_save_path, 'w') as f:
json.dump(image_path_dict, f)
print(f"Finished, {len(image_path_dict)} image files !")
return json_files, image_path_dict
@property
def modality_length(self):
length_list = [10000] * len(self.json_data)
return length_list * self.repeats
def __len__(self):
return len(self.json_data) * self.repeats
def real_len(self):
return len(self.json_data)
def decode_mask(self, object_masks, ori_height, ori_width):
binary_masks = []
for object_mask in object_masks:
binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
for seg in object_mask:
m = mask.decode(seg)
m = m.astype(np.uint8)
binary_mask += m.squeeze()
binary_masks.append(binary_mask)
if len(binary_masks) == 0:
return None
masks = np.stack(binary_masks, axis=0)
masks = torch.from_numpy(masks)
return masks
def dataset_map_fn(self, data_dict):
data_dict = glamm_grand_map_fn(data_dict)
return data_dict
def replace_image_str(self, data_dict, image_str):
data_dict['conversation'][0]['input'] = \
data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
return data_dict
def __getitem__(self, index):
index = index % self.real_len()
json_file_path = self.json_data[index]
with open(json_file_path, 'r') as f:
json_dict = json.load(f)
image_name = list(json_dict.keys())[0]
if image_name not in self.image_path_dict.keys():
return self.__getitem__(random.randint(0, len(self.json_data) - 1))
image_path = self.image_path_dict[image_name]
json_dict = json_dict[image_name]
# parse datasets
result = self.dataset_map_fn(json_dict)
json_dict.update(result)
data_dict = json_dict
data_dict['image'] = image_path
# process image
image_file = data_dict['image']
try:
image = Image.open(os.path.join(self.image_folder,
image_file)).convert('RGB')
except Exception:
return self.__getitem__(random.randint(0, len(self.json_data) - 1))
ori_width, ori_height = image.size
if hasattr(self, 'extra_image_processor'):
g_image = np.array(image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
data_dict['g_pixel_values'] = g_pixel_values
if self.single_image_mode:
images = [image]
else:
images = dynamic_preprocess(image, self.min_dynamic_patch,
self.max_dynamic_patch,
self.image_size, self.use_thumbnail)
pixel_values = [self.transformer(image) for image in images]
pixel_values = torch.stack(pixel_values)
data_dict['pixel_values'] = pixel_values
num_image_tokens = pixel_values.shape[0] * self.patch_token
image_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
data_dict = self.replace_image_str(data_dict, image_token_str)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
with_image_token=True)
data_dict.update(result)
# process mask
data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
if data_dict['masks'] is None:
return self.__getitem__(random.randint(0, len(self.json_data) - 1))
return data_dict
from .ReVOS_Dataset import VideoReVOSDataset
class VideoMeVISDataset(VideoReVOSDataset):
pass
import json
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from PIL import Image
from torch.utils.data import Dataset
from pycocotools import mask as maskUtils
import numpy as np
import copy
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import torchvision.transforms as T
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from torchvision.transforms.functional import InterpolationMode
from .encode_fn import video_lisa_encode_fn
from .utils import dynamic_preprocess
import random
import torch.nn.functional as F
class OspreyDataset(Dataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
LIMIT = ''
VP_START_TOKEN = '<vp>'
VP_END_TOKEN = '</vp>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super().__init__()
assert lazy
self.lazy = lazy
self.max_length = max_length
json_data = self.json_file_preprocess(data_path)
self.text_data = json_data
self.image_folder = image_folder
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.template_map_fn = template_map_fn
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn['type']
del self.template_map_fn['type']
self.template_map_fn = _type(**self.template_map_fn)
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.repeats = repeats
self._system = ''
self.min_dynamic_patch = 1
self.max_dynamic_patch = 12
self.downsample_ratio = 0.5
self.image_size = 448
self.use_thumbnail = True
patch_size = 14
self.patch_size = patch_size
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.single_image_mode = single_image_mode
def json_file_preprocess(self, data_path):
with open(data_path, 'r') as f:
json_data = json.load(f)
return json_data
@property
def modality_length(self):
length_list = []
for data_dict in self.text_data:
if self.lazy:
cur_len = 100
else:
cur_len = len(data_dict['input_ids'])
if data_dict.get('image', None) is None:
cur_len = -cur_len
length_list.append(cur_len)
return length_list * self.repeats
def __len__(self):
return len(self.text_data) * self.repeats
def real_len(self):
return len(self.text_data)
def annToMask(self, mask_ann, h, w):
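# COCO-style segmentation -> binary mask: a list means polygons (convert and merge
# the RLEs), a dict whose 'counts' is a list is uncompressed RLE (convert first),
# anything else is already a compressed RLE and can be decoded directly.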
if isinstance(mask_ann, list):
rles = maskUtils.frPyObjects(mask_ann, h, w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, h, w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def decode_mask(self, object_masks, ori_height, ori_width):
binary_masks = []
for object_mask in object_masks:
binary_mask = self.annToMask(object_mask, ori_height, ori_width)
binary_masks.append(binary_mask)
if len(binary_masks) == 0:
return None
masks = np.stack(binary_masks, axis=0)
masks = torch.from_numpy(masks)
return masks
def _process_conversation(self, conversations, n_regions, region_pixels):
start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
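# Build the visual-prompt prefix: "region{i}<vp>...</vp>" for every region, where the
# number of IMG_CONTEXT placeholders equals region_pixels[i], i.e. the number of
# visual-token grid cells that region's mask covers.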
for i in range(n_regions):
start_region_str = start_region_str + \
f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
if i == n_regions - 1:
start_region_str = start_region_str + '.\n'
else:
start_region_str = start_region_str + ', '
for i, item in enumerate(conversations):
item['value'] = item['value'].replace('<', '').replace('>', '')
if item['from'] == 'human':
item['value'] = item['value'] + self.LIMIT
# first conv process
if i == 0:
assert item['from'] == "human"
item['value'] = start_region_str + item['value']
messages = conversations
input = ''
conversation = []
while messages and messages[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
messages = messages[1:]
for msg in messages:
if msg['from'] == 'human':
if DEFAULT_IMAGE_TOKEN in msg['value']:
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
'').strip()
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
msg['value'] = msg['value'].strip()
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
return conversation
def _get_region_infos(self, masks):
# masks tensor, (n_obj, h, w)
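# Resize each mask to the visual-token grid (448 // 14 * 0.5 = 16, i.e. 16 x 16 cells)
# with nearest-neighbour interpolation; region_pixels[i] counts the cells covered by
# mask i and thus how many IMG_CONTEXT tokens represent that region in the prompt.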
masks = F.interpolate(
masks.unsqueeze(0),
size=(int(self.image_size // self.patch_size * self.downsample_ratio),
int(self.image_size // self.patch_size * self.downsample_ratio)),
mode='nearest').squeeze(0)
region_pixels = []
for mask in masks:
region_pixels.append(mask.bool().to(torch.int64).sum())
return masks, region_pixels
def dataset_map_fn(self, data_dict):
file_name = data_dict['file_name'] # image file name
conversations = data_dict['conversations']
masks = [anno["segmentation"] for anno in data_dict["annotation"]]
height = data_dict['height']
width = data_dict['width']
_ret = {}
_ret['image'] = file_name
_ret['height'] = height
_ret['width'] = width
masks = self.decode_mask(masks, height, width)
masks, region_pixels = self._get_region_infos(masks)
if masks is None:
return None
conversations = self._process_conversation(conversations, len(masks), region_pixels)
_ret['conversation'] = conversations
_ret['prompt_masks'] = masks
return _ret
def replace_image_str(self, data_dict, image_str):
data_dict['conversation'][0]['input'] = \
data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
return data_dict
def __getitem__(self, index):
index = index % self.real_len()
data_dict = copy.deepcopy(self.text_data[index])
# parse datasets
result = self.dataset_map_fn(data_dict) # {'image', 'height', 'width', 'conversation', 'masks'}
if result is None or result['prompt_masks'] is None:
return self.__getitem__(0)
data_dict = result
# process image
image_file = data_dict['image']
if isinstance(self.image_folder, list):
for image_folder in self.image_folder:
image_path = os.path.join(image_folder, image_file)
if os.path.exists(image_path):
image = Image.open(image_path).convert('RGB')
break
else:
image = Image.open(os.path.join(self.image_folder,
image_file)).convert('RGB')
ori_width, ori_height = image.size
if self.single_image_mode:
images = [image]
else:
images = dynamic_preprocess(image, self.min_dynamic_patch,
self.max_dynamic_patch,
self.image_size, self.use_thumbnail)
vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
data_dict['vp_overall_mask'] = vp_overall_mask
pixel_values = [self.transformer(image) for image in images]
pixel_values = torch.stack(pixel_values)
data_dict['pixel_values'] = pixel_values
num_image_tokens = pixel_values.shape[0] * self.patch_token
image_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
data_dict = self.replace_image_str(data_dict, image_token_str)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
with_image_token=True)
data_dict.update(result)
# process mask
# data_dict['prompt_masks'] = data_dict['prompt_masks']
if data_dict['prompt_masks'] is None:
return self.__getitem__(0)
return data_dict
DETAILED_QUESTIONS = [
'Can you provide me with a detailed description of the region in the picture marked by <region>?',
"I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
'What can you tell me about the region indicated by <region> in the image?',
"I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
'Could you describe the region shown as <region> in the picture in great detail?',
'What details can you give me about the region outlined by <region> in the photo?',
'Please provide me with a comprehensive description of the region marked with <region> in the image.',
'Can you give me a detailed account of the region labeled as <region> in the picture?',
"I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
"I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
'What can you tell me about the region indicated by <region> in the image, exactly?',
"I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
'Could you describe the region shown as <region> in the picture in great detail, please?',
'What details can you give me about the region outlined by <region> in the photo, please?',
'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
"I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
'Please describe the region <region> in the image in detail.',
'Can you offer a thorough analysis of the region <region> in the image?',
'Could you elaborate on the region highlighted by <region> in the picture provided?',
'Please share more information about the zone emphasized with <region> in the photo.',
'What insights can you give about the area denoted by <region> in the image presented?',
'Can you share a comprehensive rundown of the region denoted by <region> in the presented image?',
"I'd like to know more about the region highlighted by <region> in the picture provided.",
'Work through the important details of the area <region> in the image.',
'Illustrate the area represented by <region> through a descriptive explanation.',
'Examine the region <region> closely and share its details.'
]
class OspreyDescriptionDataset(OspreyDataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
VP_START_TOKEN = '<vp>'
VP_END_TOKEN = '</vp>'
LIMIT=''
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super(OspreyDescriptionDataset, self).__init__(
image_folder=image_folder,
data_path=data_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=template_map_fn,
extra_image_processor=extra_image_processor,
lazy=lazy,
repeats=repeats,
single_image_mode=single_image_mode,
)
def dataset_map_fn(self, data_dict):
file_name = data_dict['file_name'] # image file name
descriptions = data_dict['description']
masks = [anno["segmentation"] for anno in data_dict["annotation"]]
height = data_dict['height']
width = data_dict['width']
_ret = {}
_ret['image'] = file_name
_ret['height'] = height
_ret['width'] = width
masks = self.decode_mask(masks, height, width)
masks, region_pixels = self._get_region_infos(masks)
if masks is None:
return None
conversations = self._process_conversation(descriptions, len(masks), region_pixels)
_ret['conversation'] = conversations
_ret['prompt_masks'] = masks
return _ret
def _process_conversation(self, descriptions, n_regions, region_pixels):
start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
for i in range(n_regions):
start_region_str = start_region_str + \
f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
if i == n_regions - 1:
start_region_str = start_region_str + '.\n'
else:
start_region_str = start_region_str + ', '
conversations = []
for i, item in enumerate(descriptions):
question = random.choice(DETAILED_QUESTIONS).strip().replace('<region>', f"region{i+1}") + self.LIMIT
answer = item.replace('<', '').replace('>', '')
# first conv process
if i == 0:
question = start_region_str + question
conversations.append({'from': 'human', 'value': question})
conversations.append({'from': 'gpt', 'value': answer})
messages = conversations
input = ''
conversation = []
while messages and messages[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
messages = messages[1:]
for msg in messages:
if msg['from'] == 'human':
if DEFAULT_IMAGE_TOKEN in msg['value']:
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
'').strip()
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
msg['value'] = msg['value'].strip()
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
return conversation
class OspreyShortDescriptionDataset(OspreyDataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
VP_START_TOKEN = '<vp>'
VP_END_TOKEN = '</vp>'
LIMIT = ' Answer the question using a single word or phrase.'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
image_folder,
data_path=None,
tokenizer=None,
max_length=8196,
special_tokens=None,
template_map_fn=None,
extra_image_processor=None,
lazy=True,
repeats=1,
single_image_mode=False,
):
super(OspreyShortDescriptionDataset, self).__init__(
image_folder=image_folder,
data_path=data_path,
tokenizer=tokenizer,
max_length=max_length,
special_tokens=special_tokens,
template_map_fn=template_map_fn,
extra_image_processor=extra_image_processor,
lazy=lazy,
repeats=repeats,
single_image_mode=single_image_mode,
)
import logging
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import copy
from .encode_fn import video_lisa_encode_fn
import json
import random
import pycocotools.mask as maskUtils
import cv2
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
SEG_QUESTIONS = [
"Please segment the object according to the description: {class_name}",
]
SEG_QUESTIONS_SHORT = [
"Can you segment the {class_name} in this image?",
"Please segment {class_name} in this image.",
"What is {class_name} in this image? Please respond with segmentation mask.",
"What is {class_name} in this image? Please output segmentation mask.",
"Can you segment the {class_name} in this image",
"Please segment {class_name} in this image",
"What is {class_name} in this image? Please respond with segmentation mask",
"What is {class_name} in this image? Please output segmentation mask",
"Could you provide a segmentation mask for the {class_name} in this image?",
"Please identify and segment the {class_name} in this image.",
"Where is the {class_name} in this picture? Please respond with a segmentation mask.",
"Can you highlight the {class_name} in this image with a segmentation mask?",
"Could you provide a segmentation mask for the {class_name} in this image",
"Please identify and segment the {class_name} in this image",
"Where is the {class_name} in this picture? Please respond with a segmentation mask",
"Can you highlight the {class_name} in this image with a segmentation mask",
]
ANSWER_LIST = [
"It is [SEG].",
"Sure, [SEG].",
"Sure, it is [SEG].",
"Sure, the segmentation result is [SEG].",
"[SEG].",
]
class VideoSAM2Dataset(Dataset):
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
FAST_IMG_START_TOKEN = '<fast_img>'
FAST_IMG_END_TOKEN = '</fast_img>'
def __init__(self,
sam2_folder,
expression_file,
extra_image_processor=None,
tokenizer=None,
select_number=5,
sampled_frames=5,
offline_processed_text_folder=None,
template_map_fn=None,
max_length=8196,
lazy=True,
repeats=1,
special_tokens=None,
use_fast=False,
n_fast_images=50,
fast_pool_size=4,
mode='long',
frame_contiguous_sample=False,
):
assert mode in ['long', 'long_short', 'short']
self.mode = mode
self.cur_mode = mode
assert lazy is True
self.tokenizer = BUILDER.build(tokenizer)
self.select_number = select_number
self.sampled_frames = sampled_frames
assert offline_processed_text_folder or (expression_file and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.template_map_fn = template_map_fn
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn['type']
del self.template_map_fn['type']
self.template_map_fn = _type(**self.template_map_fn)
if offline_processed_text_folder and expression_file:
print_log(
'Both `offline_processed_text_folder` and '
'`data_path` are set, and we load dataset from'
'`offline_processed_text_folder` '
f'({offline_processed_text_folder})',
logger='current',
level=logging.WARNING)
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
video_ids, anno_dict = self.json_file_preprocess(expression_file)
if self.lazy:
self.video_ids = video_ids
self.anno_dict = anno_dict
else:
raise NotImplementedError
self.sam2_folder = sam2_folder
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.down_ratio = 1
self.repeats = repeats
self._system = ''
self.downsample_ratio = 0.5
self.image_size = 448
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.use_fast = use_fast
self.n_fast_images = n_fast_images
self.fast_pool_size = fast_pool_size
self.frame_contiguous_sample = frame_contiguous_sample
# for visualization debug
self.save_folder = './work_dirs/video_debug/'
self.cur_number = 0
print("Video res dataset (ref-sam2), include {} items.".format(len(self.video_ids)))
def __len__(self):
return len(self.video_ids) * self.repeats
@property
def modality_length(self):
length_list = []
for data_dict in self.video_ids:
cur_len = 20000
length_list.append(cur_len)
return length_list
def real_len(self):
return len(self.video_ids)
def json_file_preprocess(self, expression_file):
# prepare expression annotation files
with open(expression_file, 'r') as f:
expression_datas = json.load(f)
video_ids = list(expression_datas.keys())
return video_ids, expression_datas
def dataset_map_fn(self, objects_expression_infos, n_frames, n_fast_frames=0):
# prepare text
if self.mode == 'long':
expressions = [object_info['formated'] for object_info in objects_expression_infos]
self.cur_mode = self.mode
elif self.mode == 'short':
expressions = [object_info['short_caps'][random.randint(0, len(object_info['short_caps'])-1)] for object_info in objects_expression_infos]
self.cur_mode = self.mode
else:
if random.random() < 0.5:
expressions = [object_info['formated'] for object_info in objects_expression_infos]
self.cur_mode = 'long'
else:
expressions = [object_info['short_caps'][random.randint(0, len(object_info['short_caps']) - 1)] for
object_info in objects_expression_infos]
self.cur_mode = 'short'
text_dict = self.prepare_text(n_frames, expressions, num_image_tokens=self.patch_token,
n_fast_frames=n_fast_frames)
ret = {'conversation': text_dict['conversation']}
return ret
def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_frames=0):
if self.use_fast:
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_frames * self.fast_pool_size * self.fast_pool_size}' \
f'{self.FAST_IMG_END_TOKEN}' + '\n'
else:
fast_frame_token_str = ''
frame_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
questions = []
answers = []
for i, exp in enumerate(expressions):
if self.cur_mode == 'short':
question_template = random.choice(SEG_QUESTIONS_SHORT)
exp = exp.replace("A ", '')
else:
question_template = random.choice(SEG_QUESTIONS)
questions.append(question_template.format(class_name=exp))
answers.append(random.choice(ANSWER_LIST))
qa_list = []
for i, (question, answer) in enumerate(zip(questions, answers)):
if i == 0:
frame_tokens = frame_token_str + '\n'
# frame_tokens = '=' + ' '
frame_tokens = frame_tokens * n_frames
frame_tokens = frame_tokens.strip()
frame_tokens = fast_frame_token_str + frame_tokens
qa_list.append(
{'from': 'human', 'value': frame_tokens + question}
)
else:
qa_list.append(
{'from': 'human', 'value': question}
)
qa_list.append(
{'from': 'gpt', 'value': answer}
)
input = ''
conversation = []
for msg in qa_list:
if msg['from'] == 'human':
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
# add system information
conversation[0].update({'system': self._system})
return {'conversation': conversation}
def __getitem__(self, index):
index = index % self.real_len()
video_id = self.video_ids[index]
expression_dict = self.anno_dict[video_id]
object_ids = list(expression_dict['objects'].keys())
video_path = os.path.join(self.sam2_folder, expression_dict['video_path'])
anno_path = os.path.join(self.sam2_folder, expression_dict['anno_path'])
video_frames = get_video_frames(video_path)
if self.use_fast:
# sample fast branch
fast_interval = len(video_frames) / (self.n_fast_images + 1e-4)
sampled_fast_frame_idxs = [min(int(i * fast_interval), len(video_frames) - 1) for i in range(self.n_fast_images)]
fast_video_frames = [video_frames[_idx] for _idx in sampled_fast_frame_idxs]
else:
fast_video_frames = None
video_frames = video_frames[::4]
# mask annotation
with open(anno_path, 'r') as f:
mask_data = json.load(f)
masklents = decode_masklet(mask_data['masklet'])
n_frames = len(masklents)
n_objects = len(object_ids)
# sample object
if n_objects > self.select_number:
    selected_indexes = np.random.choice(n_objects, self.select_number, replace=False)
else:
    selected_indexes = np.random.choice(n_objects, self.select_number, replace=True)
selected_object_ids = [object_ids[_idx] for _idx in selected_indexes]
objects_expression_infos = [expression_dict['objects'][_idx] for _idx in selected_object_ids]
_masklents = []
for _mask in masklents:
_mask_selected = []
for _idx in selected_object_ids:
_mask_selected.append(_mask[:, :, int(_idx)])
_mask_selected = np.stack(_mask_selected, axis=2)
_masklents.append(_mask_selected)
masklents = _masklents
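# Each decoded masklet frame is an (H, W, n_total_objects) array; keep only the
# channels that belong to the sampled object ids.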
# sample video frames
# prepare images, random select k frames
if n_frames > self.sampled_frames + 1:
if self.frame_contiguous_sample and random.random() < 0.5:
# do contiguous sample
selected_start_frame = np.random.choice(n_frames - self.sampled_frames, 1, replace=False)
selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(self.sampled_frames)]
else:
selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=False)
else:
selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=True)
selected_frame_indexes.sort()
video_frames = [video_frames[_idx] for _idx in selected_frame_indexes]
masklents = [masklents[_idx] for _idx in selected_frame_indexes]
data_dict = self.dataset_map_fn(objects_expression_infos, len(video_frames), n_fast_frames=self.n_fast_images)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
data_dict.update(result)
pixel_values = []
extra_pixel_values = []
for frame in video_frames:
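# OpenCV decodes frames as BGR; flip the channel order to RGB before the PIL conversion.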
frame = frame[:, :, ::-1]
frame_image = Image.fromarray(frame).convert('RGB')
ori_width, ori_height = frame_image.size
if hasattr(self, 'extra_image_processor'):
g_image = np.array(frame_image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
extra_pixel_values.append(g_pixel_values)
frame_image = self.transformer(frame_image)
pixel_values.append(frame_image)
pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['pixel_values'] = pixel_values
if hasattr(self, 'extra_image_processor'):
data_dict['g_pixel_values'] = extra_pixel_values
# for fast branch
if self.use_fast:
fast_pixel_values = []
for frame_image in fast_video_frames:
frame = frame_image[:, :, ::-1]
frame_image = Image.fromarray(frame).convert('RGB')
ori_width, ori_height = frame_image.size
frame_image = self.transformer(frame_image)
fast_pixel_values.append(frame_image)
fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['fast_pixel_values'] = fast_pixel_values
# process and get masks
masklents = np.stack(masklents, axis=0) # (n_frames, h, w, n_obj)
masklents = torch.from_numpy(masklents).permute(3, 0, 1, 2)
masklents = masklents.flatten(0, 1)
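# Shape: (n_obj, n_frames, h, w) -> (n_obj * n_frames, h, w), object-major
# (all frames of object 0 first, then object 1, and so on).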
# print('sam2-mask_shape:', masklents.shape)
# print('sam2-pixel_values:', data_dict['pixel_values'].shape)
# print('sam2-g_pixel_values:', len(data_dict['g_pixel_values']), ', ', data_dict['g_pixel_values'][0].shape)
data_dict['masks'] = masklents
data_dict['type'] = 'video'
return data_dict
def visualization_debug(self, data_dict):
save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
if not os.path.exists(save_folder):
os.mkdir(save_folder)
self.cur_number += 1
# images
show_images = []
pixel_values = data_dict['pixel_values']
save_folder_image = os.path.join(save_folder, 'image')
if not os.path.exists(save_folder_image):
os.mkdir(save_folder_image)
for i_image, image_pixel_value in enumerate(pixel_values):
# print(image_pixel_value.shape)
for c in range(3):
    image_pixel_value[c] = image_pixel_value[c] * self.IMAGENET_STD[c] + self.IMAGENET_MEAN[c]
image_pixel_value = image_pixel_value * 255
image_pixel_value = image_pixel_value.permute(1, 2, 0)
image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
# print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
# print(image_pixel_value.shape)
show_images.append(image_pixel_value)
cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
# text
input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
with open(os.path.join(save_folder, 'text.json'), 'w') as f:
json.dump([input_text], f)
# masks
save_folder_mask = os.path.join(save_folder, 'mask')
if not os.path.exists(save_folder_mask):
os.mkdir(save_folder_mask)
n_frames = len(pixel_values)
masks = data_dict['masks']
_, h, w = masks.shape
masks = masks.reshape(-1, n_frames, h, w)
for i_obj, obj_masks in enumerate(masks):
save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
if not os.path.exists(save_folder_mask_obj_folder):
os.mkdir(save_folder_mask_obj_folder)
for i_frame, f_mask in enumerate(obj_masks):
f_mask = f_mask.numpy()
f_mask = f_mask * 255
f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
f_mask = f_mask.astype(np.uint8)
cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
return
def get_video_frames(video_path):
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print("Error: Cannot open video file.")
return
frames = []
frame_id = 0
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
frame_id += 1
cap.release()
return frames
def images_to_video(frames, video_name, fps=6):
height, width, layers = frames[0].shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
for frame in frames:
video.write(frame)
# cv2.destroyAllWindows()
video.release()
return
def decode_masklet(masklet):
masks = []
for _rle in masklet:
mask = maskUtils.decode(_rle)
masks.append(mask)
return masks
def draw_mask(image, mask):
obj_mask = mask * 255
obj_mask = np.stack([obj_mask * 1, obj_mask * 0, obj_mask * 0], axis=2)
obj_mask = obj_mask * 0.5 + copy.deepcopy(image) * 0.5
obj_mask = obj_mask.astype(np.uint8)
return obj_mask
def add_mask2images(frames, masklets):
show_videos = []
for i_frames, (frame, masks) in enumerate(zip(frames, masklets)):
if i_frames == 0:
n_obj = masks.shape[-1]
for i_obj in range(n_obj):
show_videos.append([])
n_obj = masks.shape[-1]
for i_obj in range(n_obj):
show_videos[i_obj].append(draw_mask(copy.deepcopy(frame), masks[:, :, i_obj]))
return show_videos
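# A minimal sketch of how the helpers above could be chained for a quick visual check
# (file names are placeholders; the per-frame masks must be (h, w, n_obj) arrays):
#   frames = get_video_frames('input.mp4')
#   show_videos = add_mask2images(frames, masklets)   # masklets: one (h, w, n_obj) array per frame
#   for i_obj, obj_frames in enumerate(show_videos):
#       images_to_video(obj_frames, 'obj_{}.mp4'.format(i_obj))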
import logging
import os
from typing import Literal
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import build_origin_dataset
import copy
from .encode_fn import video_lisa_encode_fn
import json
import random
import pycocotools.mask as maskUtils
import cv2
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
SEG_QUESTIONS = [
"Can you segment the {class_name} in this image?",
"Please segment {class_name} in this image.",
"What is {class_name} in this image? Please respond with segmentation mask.",
"What is {class_name} in this image? Please output segmentation mask.",
"Can you segment the {class_name} in this image",
"Please segment {class_name} in this image",
"What is {class_name} in this image? Please respond with segmentation mask",
"What is {class_name} in this image? Please output segmentation mask",
"Could you provide a segmentation mask for the {class_name} in this image?",
"Please identify and segment the {class_name} in this image.",
"Where is the {class_name} in this picture? Please respond with a segmentation mask.",
"Can you highlight the {class_name} in this image with a segmentation mask?",
"Could you provide a segmentation mask for the {class_name} in this image",
"Please identify and segment the {class_name} in this image",
"Where is the {class_name} in this picture? Please respond with a segmentation mask",
"Can you highlight the {class_name} in this image with a segmentation mask",
]
ANSWER_LIST = [
"It is [SEG].",
"Sure, [SEG].",
"Sure, it is [SEG].",
"Sure, the segmentation result is [SEG].",
"[SEG].",
]
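# SEG_QUESTIONS / ANSWER_LIST are sampled in prepare_text() below: a template is filled with the
# referring expression and paired with an answer carrying the [SEG] placeholder used by the
# segmentation branch. Illustrative pair (expression made up for the example):
#   question: "Can you segment the person in the red coat in this image?"
#   answer:   "Sure, it is [SEG]."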
class VideoReVOSDataset(Dataset):
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
FAST_IMG_START_TOKEN = '<fast_img>'
FAST_IMG_END_TOKEN = '</fast_img>'
def __init__(self,
image_folder,
expression_file,
mask_file,
extra_image_processor=None,
tokenizer=None,
select_number=5,
sampled_frames=10,
offline_processed_text_folder=None,
template_map_fn=None,
max_length=2048,
lazy=True,
repeats=1,
special_tokens=None,
frame_contiguous_sample=False,
use_fast=False,
arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
preprocessor=None,
# only work if use_fast = True
n_fast_images=50,
fast_pool_size=4,
fast_token_after_question=False,
):
assert lazy is True
self.tokenizer = BUILDER.build(tokenizer)
self.select_number = select_number
self.sampled_frames = sampled_frames
assert offline_processed_text_folder or (expression_file and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.template_map_fn = template_map_fn
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn['type']
del self.template_map_fn['type']
self.template_map_fn = _type(**self.template_map_fn)
if offline_processed_text_folder and expression_file:
print_log(
'Both `offline_processed_text_folder` and '
'`expression_file` are set, and we load the dataset from '
'`offline_processed_text_folder` '
f'({offline_processed_text_folder})',
logger='current',
level=logging.WARNING)
self.arch_type = arch_type
if self.arch_type == 'qwen':
self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
self.IMG_START_TOKEN = '<|vision_start|>'
self.IMG_END_TOKEN = '<|vision_end|>'
elif self.arch_type == 'llava':
self.IMG_CONTEXT_TOKEN = '<image>'
self.IMG_START_TOKEN = ''
self.IMG_END_TOKEN = ''
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
vid2metaid, metas, mask_dict = self.json_file_preprocess(expression_file, mask_file)
self.vid2metaid = vid2metaid
self.videos = list(self.vid2metaid.keys())
self.mask_dict = mask_dict
self.json_datas = metas
json_datas = metas
json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
if self.lazy:
self.text_data = build_origin_dataset(json_data, 'train')
else:
raise NotImplementedError
self.image_folder = image_folder
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
else:
self.extra_image_processor = None  # keep the attribute so the `is not None` checks in __getitem__ are safe
self.down_ratio = 1
self.repeats = repeats
self._system = ''
self.downsample_ratio = 0.5
if self.arch_type == 'llava':
self.downsample_ratio = 1
self.image_size = 448
if self.arch_type == 'llava':
self.image_size = 336
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
if self.arch_type == 'qwen':
self.patch_token = 1
if preprocessor is None:
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
self.preprocessor = None
else:
self.transformer = None
self.preprocessor = BUILDER.build(preprocessor)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.use_fast = use_fast
self.n_fast_images = n_fast_images
self.fast_pool_size = fast_pool_size
self.frame_contiguous_sample = frame_contiguous_sample
# for visualization debug
self.save_folder = './work_dirs/video_debug/'
self.cur_number = 0
# exist_thr
self.exist_thr = 8
self.fast_token_after_question = fast_token_after_question
if self.fast_token_after_question:
assert self.use_fast
print("Video res dataset, include {} items.".format(len(self.vid2metaid)))
def __len__(self):
return len(self.vid2metaid) * self.repeats
@property
def modality_length(self):
length_list = []
for data_dict in self.vid2metaid:
cur_len = 10000
length_list.append(cur_len)
return length_list
def real_len(self):
return len(self.vid2metaid)
def json_file_preprocess(self, expression_file, mask_file):
# prepare expression annotation files
with open(expression_file, 'r') as f:
expression_datas = json.load(f)['videos']
metas = []
anno_count = 0 # serve as anno_id
vid2metaid = {}
for vid_name in expression_datas:
vid_express_data = expression_datas[vid_name]
vid_frames = sorted(vid_express_data['frames'])
vid_len = len(vid_frames)
exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
for exp_id in exp_id_list:
exp_dict = vid_express_data['expressions'][exp_id]
meta = {}
meta['video'] = vid_name
meta['exp'] = exp_dict['exp'] # str
meta['mask_anno_id'] = exp_dict['anno_id']
if 'obj_id' in exp_dict.keys():
meta['obj_id'] = exp_dict['obj_id']
else:
meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
meta['anno_id'] = [str(anno_count), ]
anno_count += 1
meta['frames'] = vid_frames
meta['exp_id'] = exp_id
meta['length'] = vid_len
metas.append(meta)
if vid_name not in vid2metaid.keys():
vid2metaid[vid_name] = []
vid2metaid[vid_name].append(len(metas) - 1)
# process mask annotation files
with open(mask_file, 'rb') as f:
mask_dict = json.load(f)
return vid2metaid, metas, mask_dict
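# Shape of the structures returned above:
#   vid2metaid: {video_name: [indices into metas, one per expression of that video]}
#   metas:      one dict per expression with keys
#               'video', 'exp', 'mask_anno_id', 'obj_id', 'anno_id', 'frames', 'exp_id', 'length'
#   mask_dict:  maps an annotation id to its per-frame mask annotations (RLE or None),
#               looked up later in dataset_map_fn via 'mask_anno_id'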
def create_img_to_refs_mapping(self, refs_train):
img2refs = {}
for ref in refs_train:
img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
return img2refs
def decode_mask(self, video_masks, image_size):
ret_masks = []
for object_masks in video_masks:
# None object
if len(object_masks) == 0:
if len(ret_masks) != 0:
_object_masks = ret_masks[0] * 0
else:
_object_masks = np.zeros(
(self.sampled_frames, image_size[0], image_size[1]), dtype=np.uint8)
else:
_object_masks = []
for i_frame in range(len(object_masks[0])):
_mask = np.zeros(image_size, dtype=np.uint8)
for i_anno in range(len(object_masks)):
if object_masks[i_anno][i_frame] is None:
continue
m = maskUtils.decode(object_masks[i_anno][i_frame])
if m.ndim == 3:
m = m.sum(axis=2).astype(np.uint8)
else:
m = m.astype(np.uint8)
_mask = _mask | m
_object_masks.append(_mask)
_object_masks = np.stack(_object_masks, axis=0)
# if self.pad_image_to_square:
# _object_masks = expand2square_mask(_object_masks)
ret_masks.append(_object_masks)
_shape = ret_masks[0].shape
for item in ret_masks:
if item.shape != _shape:
print([_ret_mask.shape for _ret_mask in ret_masks])
return None
ret_masks = np.stack(ret_masks, axis=0) # (n_obj, n_frames, h, w)
ret_masks = torch.from_numpy(ret_masks)
# ret_masks = F.interpolate(ret_masks, size=(self.image_size // self.down_ratio,
# self.image_size // self.down_ratio), mode='nearest')
ret_masks = ret_masks.flatten(0, 1)
return ret_masks
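# decode_mask() ORs every RLE annotation of an object frame by frame into a binary uint8 mask;
# objects with no annotation at all get an all-zero placeholder. The final flatten(0, 1) yields a
# (n_obj * n_frames, h, w) tensor, matching the 'masks' layout expected downstream.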
def dataset_map_fn(self, data_dict, select_k=5):
images = []
len_frames = len(data_dict[0]['frames'])
for object_info in data_dict:
assert len_frames == len(object_info['frames'])
# prepare images, random select k frames
if len_frames > select_k + 1:
if self.frame_contiguous_sample and random.random() < 0.5:
# do contiguous sample
selected_start_frame = np.random.choice(len_frames - select_k, 1, replace=False)
selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(select_k)]
else:
selected_frame_indexes = np.random.choice(len_frames, select_k, replace=False)
else:
selected_frame_indexes = np.random.choice(len_frames, select_k, replace=True)
selected_frame_indexes.sort()
if self.use_fast:
# sample fast branch
fast_interval = len_frames / (self.n_fast_images + 1e-4)
sampled_fast_frame_idxs = [min(int(i * fast_interval), len_frames - 1) for i in range(self.n_fast_images)]
fast_video_frames = []
for selected_frame_index in sampled_fast_frame_idxs:
frame_id = data_dict[0]['frames'][selected_frame_index]
fast_video_frames.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
else:
fast_video_frames = None
sampled_fast_frame_idxs = None
for selected_frame_index in selected_frame_indexes:
frame_id = data_dict[0]['frames'][selected_frame_index]
images.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
# prepare text
expressions = [object_info['exp'] for object_info in data_dict]
if self.use_fast:
text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token,
n_fast_images=len(fast_video_frames),)
else:
text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token)
# prepare masks
video_masks = []
for object_info in data_dict:
anno_ids = object_info['mask_anno_id']
# print('anno_ids: ', anno_ids)
obj_masks = []
for anno_id in anno_ids:
anno_id = str(anno_id)
frames_masks = self.mask_dict[anno_id]
frames_masks_ = []
for frame_idx in selected_frame_indexes:
frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
obj_masks.append(frames_masks_)
video_masks.append(obj_masks)
if self.use_fast:
fast_video_masks = []
assert sampled_fast_frame_idxs is not None
for object_info in data_dict:
anno_ids = object_info['mask_anno_id']
obj_masks = []
for anno_id in anno_ids:
anno_id = str(anno_id)
frames_masks = self.mask_dict[anno_id]
frames_masks_ = []
for frame_idx in sampled_fast_frame_idxs:
frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
obj_masks.append(frames_masks_)
fast_video_masks.append(obj_masks)
else:
fast_video_masks = None
ret = {'images': images, 'video_masks': video_masks, 'conversation': text_dict['conversation'],
'fast_images': fast_video_frames, 'fast_video_masks': fast_video_masks}
return ret
def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_images=50):
if self.use_fast and not self.fast_token_after_question:
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
f'{self.FAST_IMG_END_TOKEN}' + '\n'
else:
fast_frame_token_str = ''
frame_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
if self.fast_token_after_question:
assert self.use_fast
after_question_str = f'{self.FAST_IMG_START_TOKEN}' \
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
f'{self.FAST_IMG_END_TOKEN}'
else:
after_question_str = ''
questions = []
answers = []
for i, exp in enumerate(expressions):
# the exp is a question
if '?' in exp:
questions.append(exp)
else:
exp = exp.replace('.', '').strip()
question_template = random.choice(SEG_QUESTIONS)
questions.append(question_template.format(class_name=exp.lower()))
answers.append(random.choice(ANSWER_LIST))
qa_list = []
for i, (question, answer) in enumerate(zip(questions, answers)):
if i == 0:
frame_tokens = frame_token_str + '\n'
# frame_tokens = '=' + ' '
frame_tokens = frame_tokens * n_frames
frame_tokens = frame_tokens.strip()
frame_tokens = fast_frame_token_str + frame_tokens
qa_list.append(
{'from': 'human', 'value': frame_tokens + question + after_question_str}
)
else:
qa_list.append(
{'from': 'human', 'value': question + after_question_str}
)
qa_list.append(
{'from': 'gpt', 'value': answer}
)
input = ''
conversation = []
for msg in qa_list:
if msg['from'] == 'human':
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
# add system information
conversation[0].update({'system': self._system})
return {'conversation': conversation}
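# Layout of the first human turn built above (illustrative):
#   [<fast_img><FAST_IMG_CONTEXT>*(n_fast_images*fast_pool_size^2)</fast_img>\n]  # only if use_fast and not fast_token_after_question
#   <img><IMG_CONTEXT>*num_image_tokens</img>\n  ... repeated n_frames times ...
#   <question>
#   [<fast_img>...</fast_img>]                                                    # only if fast_token_after_question
# Later human turns contain just the question; every gpt turn is one of the ANSWER_LIST strings.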
def __getitem__(self, index):
index = index % self.real_len()
selected_video_objects = self.vid2metaid[self.videos[index]]
video_objects_infos = [copy.deepcopy(self.text_data[idx]) for idx in selected_video_objects]
if len(video_objects_infos) > self.select_number:
selected_indexes = np.random.choice(len(video_objects_infos), self.select_number)
video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]
else:
selected_indexes = np.random.choice(len(video_objects_infos), self.select_number, replace=True)
video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]
data_dict = self.dataset_map_fn(video_objects_infos, select_k=self.sampled_frames)
assert 'images' in data_dict.keys()
pixel_values = []
extra_pixel_values = []
num_video_tokens = None
num_frame_tokens = None
if data_dict.get('images', None) is not None:
frames_files = data_dict['images']
frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
for frame_path in frames_files:
frame_image = Image.open(frame_path).convert('RGB')
ori_width, ori_height = frame_image.size
if self.extra_image_processor is not None:
g_image = np.array(frame_image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
extra_pixel_values.append(g_pixel_values)
if self.preprocessor is not None:
pass
else:
frame_image = self.transformer(frame_image)
pixel_values.append(frame_image)
if self.preprocessor is not None:
if self.arch_type == 'qwen':
_data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
_data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
num_frames = _data_dict['image_grid_thw'].shape[0]
num_video_tokens = num_frame_tokens * num_frames
elif self.arch_type == 'llava':
_data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
_data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
else:
raise NotImplementedError
data_dict.update(_data_dict)
else:
pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['pixel_values'] = pixel_values
if self.extra_image_processor is not None:
data_dict['g_pixel_values'] = extra_pixel_values
# process and get masks
masks = self.decode_mask(data_dict['video_masks'], image_size=(ori_height, ori_width))
if masks is None:
return self.__getitem__(random.randint(0, self.real_len()))
data_dict['masks'] = masks
else:
data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
data_dict['masks'] = None
if num_video_tokens is not None:
assert self.patch_token == 1
input_str = data_dict['conversation'][0]['input']
input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
data_dict['conversation'][0]['input'] = input_str
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length)
data_dict.update(result)
# for fast branch
if self.use_fast:
fast_pixel_values = []
frames_files = data_dict['fast_images']
frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
for frame_path in frames_files:
frame_image = Image.open(frame_path).convert('RGB')
ori_width, ori_height = frame_image.size
frame_image = self.transformer(frame_image)
fast_pixel_values.append(frame_image)
fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['fast_pixel_values'] = fast_pixel_values
# process and get masks
masks = self.decode_mask(data_dict['fast_video_masks'], image_size=(ori_height, ori_width))
if masks is None:
return self.__getitem__(random.randint(0, self.real_len()))
data_dict['fast_exists'] = masks.to(dtype=torch.int).sum(dim=(-2, -1)).ge(self.exist_thr).unsqueeze(-1)
del data_dict['fast_video_masks']
data_dict['type'] = 'video'
return data_dict
def visualization_debug(self, data_dict):
save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
if not os.path.exists(save_folder):
os.mkdir(save_folder)
self.cur_number += 1
# images
show_images = []
pixel_values = data_dict['pixel_values']
save_folder_image = os.path.join(save_folder, 'image')
if not os.path.exists(save_folder_image):
os.mkdir(save_folder_image)
for i_image, image_pixel_value in enumerate(pixel_values):
# print(image_pixel_value.shape)
# undo the Normalize(self.IMAGENET_MEAN, self.IMAGENET_STD) applied by self.transformer
for _c in range(3):
image_pixel_value[_c] = image_pixel_value[_c] * self.IMAGENET_STD[_c] + self.IMAGENET_MEAN[_c]
image_pixel_value = image_pixel_value * 255
image_pixel_value = image_pixel_value.permute(1, 2, 0)
image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
# print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
# print(image_pixel_value.shape)
show_images.append(image_pixel_value)
cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
# text
input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
with open(os.path.join(save_folder, 'text.json'), 'w') as f:
json.dump([input_text], f)
# masks
save_folder_mask = os.path.join(save_folder, 'mask')
if not os.path.exists(save_folder_mask):
os.mkdir(save_folder_mask)
n_frames = len(pixel_values)
masks = data_dict['masks']
_, h, w = masks.shape
masks = masks.reshape(-1, n_frames, h, w)
for i_obj, obj_masks in enumerate(masks):
save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
if not os.path.exists(save_folder_mask_obj_folder):
os.mkdir(save_folder_mask_obj_folder)
for i_frame, f_mask in enumerate(obj_masks):
f_mask = f_mask.numpy()
f_mask = f_mask * 255
f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
f_mask = f_mask.astype(np.uint8)
cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
return
import copy
import random
import glob
import json
import logging
import os
from typing import Literal
import torch
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils
from xtuner.registry import BUILDER
from xtuner.utils import IGNORE_INDEX
from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn
from projects.glamm.datasets.utils.utils import expand2square
from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from third_parts.mmdet.datasets.refcoco import RefCocoDataset
from .utils import dynamic_preprocess
class ReferSegmDataset(RefCocoDataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
data_root,
ann_file=None,
split_file=None,
special_tokens=None,
prompt_template=None,
extra_image_processor=None,
data_prefix=dict(img_path='train2014/'),
tokenizer=None,
max_length=2048,
num_classes_per_sample=3,
single_image_mode=False,
arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
preprocessor=None,
**kwargs):
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
pipeline=None,
ann_file=ann_file,
split_file=split_file,
**kwargs,
)
self.begin_str = f'{DEFAULT_IMAGE_TOKEN}\n'
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.arch_type = arch_type
if self.arch_type == 'qwen':
self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
self.IMG_START_TOKEN = '<|vision_start|>'
self.IMG_END_TOKEN = '<|vision_end|>'
elif self.arch_type == 'llava':
self.IMG_CONTEXT_TOKEN = '<image>'
self.IMG_START_TOKEN = ''
self.IMG_END_TOKEN = ''
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.image_folder = data_root
self.template = prompt_template
self.max_length = max_length
if self.arch_type == 'intern_vl':
# self._system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
self._system = ''
self.template['INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'
elif self.arch_type == 'qwen':
self._system = ''
elif self.arch_type == 'llava':
self._system = ''
self.num_classes_per_sample = num_classes_per_sample
self.min_dynamic_patch = 1
self.max_dynamic_patch = 12
self.downsample_ratio = 0.5
if self.arch_type == 'llava':
self.downsample_ratio = 1
self.image_size = 448
if self.arch_type == 'llava':
self.image_size = 336
self.use_thumbnail = True
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
if preprocessor is None:
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
self.preprocessor = None
else:
self.transformer = None
self.preprocessor = BUILDER.build(preprocessor)
self.arch_type = arch_type
self.single_image_mode = single_image_mode
self._max_refetch = 1000
print("Image RES dataset, include {} items.".format(len(self)))
@property
def modality_length(self):
length_list = []
for idx in range(len(self)):
length_list.append(100)
return length_list
def _parse_annotations(self, ann_info):
image_path = ann_info['img_path']
image = Image.open(image_path).convert('RGB')
width, height = image.size
masks, phrases = [], []
instances, text = ann_info['instances'], ann_info['text']
# index = np.random.choice(range(len(instances)), min(
# len(instances), self.num_classes_per_sample))
index = np.random.choice(range(len(instances)), self.num_classes_per_sample, replace=True)
for idx in index:
inst = instances[idx]
phrase = text[idx].lower()
if '.' == phrase[-1]:
phrase = phrase[:-1]
phrases.append(phrase)
binary_mask = np.zeros((height, width), dtype=np.uint8)
for seg in inst["mask"]:
rles = mask_utils.frPyObjects([seg], height, width)
m = mask_utils.decode(rles)
m = m.astype(np.uint8)
binary_mask += m.squeeze()
masks.append(binary_mask)
conversation = []
for i, phrase in enumerate(phrases):
question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
if i == 0:
question = self.begin_str + question
conversation.append({'from': 'human', 'value': question})
conversation.append({'from': 'gpt', 'value': random.choice(ANSWER_LIST)})
masks = torch.stack([torch.from_numpy(mask) for mask in masks], dim=0)
ann_info.update({
'masks': masks,
'conversations': conversation,
'image': image_path
})
return ann_info
def prepare_data(self, index):
data_dict = super().prepare_data(index)
data_dict = self._parse_annotations(data_dict)
if data_dict is None:
return None
out_data_dict = {}
if 'masks' in data_dict:
out_data_dict['masks'] = data_dict['masks']
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
try:
image = Image.open(image_file).convert('RGB')
except Exception as e:
print(f'Error: {e}', flush=True)
print_log(f'Error: {e}', logger='current')
return None
if hasattr(self, 'extra_image_processor'):
g_image = np.array(image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
out_data_dict['g_pixel_values'] = g_pixel_values
if self.single_image_mode:
images = [image]
else:
images = dynamic_preprocess(image, self.min_dynamic_patch,
self.max_dynamic_patch,
self.image_size, self.use_thumbnail)
if self.preprocessor is not None:
if self.arch_type == 'qwen':
_data_dict = self.preprocessor(images, do_resize=True)
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
_data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
elif self.arch_type == 'llava':
_data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
_data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
else:
raise NotImplementedError
out_data_dict.update(_data_dict)
else:
pixel_values = [self.transformer(image) for image in images]
pixel_values = torch.stack(pixel_values)
out_data_dict['pixel_values'] = pixel_values
num_image_tokens = pixel_values.shape[0] * self.patch_token
image_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
token_dict = self.get_inputid_labels(data_dict['conversations'], image_token_str)
out_data_dict.update(token_dict)
else:
token_dict = self.get_inputid_labels(data_dict['conversations'], None)
out_data_dict.update(token_dict)
out_data_dict['pixel_values'] = torch.zeros(1, 3, self.image_size, self.image_size)
return out_data_dict
def get_inputid_labels(self, conversations, image_token_str) -> dict:
input = ''
out_conversation = []
while conversations and conversations[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
conversations = conversations[1:]
for msg in conversations:
if msg['from'] == 'human':
if image_token_str is None and '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>', '')
if '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
input += msg['value'].strip()
elif msg['from'] == 'gpt':
out_conversation.append({
'input': input,
'output': msg['value'].strip()
})
input = ''
else:
raise NotImplementedError
input_ids, labels = [], []
for i, single_turn_conversation in enumerate(out_conversation):
input = single_turn_conversation.get('input', '')
if input is None:
input = ''
input_text = self.template.INSTRUCTION.format(
input=input, round=i + 1)
if i == 0:
if self._system != '' and self._system is not None:
system = self.template.SYSTEM.format(system=self._system)
input_text = system + input_text
input_encode = self.tokenizer.encode(
input_text, add_special_tokens=True)
else:
input_encode = self.tokenizer.encode(
input_text, add_special_tokens=False)
input_ids += input_encode
labels += [IGNORE_INDEX] * len(input_encode)
output_text = single_turn_conversation.get('output', '')
if self.template.get('SUFFIX', None):
output_text += self.template.SUFFIX
output_encode = self.tokenizer.encode(
output_text, add_special_tokens=False)
input_ids += output_encode
labels += copy.deepcopy(output_encode)
if len(input_ids) > self.max_length:
input_ids = input_ids[:self.max_length]
labels = labels[:self.max_length]
# print('len_ids: ', len(input_ids))
return {'input_ids': input_ids, 'labels': labels}
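# get_inputid_labels() masks every prompt token with IGNORE_INDEX and keeps real token ids only for
# the answer (plus the template SUFFIX), so loss is computed on outputs alone. Sketch with a toy
# single-turn conversation (values are illustrative, template/tokenizer are whatever was configured):
#   token_dict = self.get_inputid_labels(
#       [{'from': 'human', 'value': '<image>\nCan you segment the cat in this image?'},
#        {'from': 'gpt', 'value': 'Sure, it is [SEG].'}],
#       image_token_str)
#   # token_dict['input_ids'] and token_dict['labels'] have equal length, truncated to max_length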
def __getitem__(self, index):
for _ in range(self._max_refetch + 1):
data = self.prepare_data(index)
# Broken images may cause the returned data to be None
if data is None:
index = self._rand_another()
continue
return data
if __name__ == '__main__':
from transformers import CLIPImageProcessor, AutoTokenizer
from third_parts.segment_anything.utils.transforms import ResizeLongestSide
pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=llm_name_or_path)
image_processor = dict(
type=CLIPImageProcessor.from_pretrained,
pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
extra_image_processor = dict(
type=ResizeLongestSide,
target_length=1024,
)
from xtuner.utils.templates import PROMPT_TEMPLATE
prompt_template = PROMPT_TEMPLATE.vicuna
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
dataset = ReferSegmDataset(
tokenizer=tokenizer,
special_tokens=['[SEG]'],
extra_image_processor=extra_image_processor,
prompt_template=prompt_template,
data_root='data/coco/',
data_prefix=dict(img_path='train2014/'),
ann_file='refcoco+/instances.json',
split_file='refcoco+/refs(unc).p',
)
for i in range(1000):
dataset[i]
from .ReVOS_Dataset import VideoReVOSDataset
import json
import pickle
class VideoRefYoutubeVOSDataset(VideoReVOSDataset):
def json_file_preprocess(self, expression_file, mask_file):
# prepare expression annotation files
with open(expression_file, 'r') as f:
expression_datas = json.load(f)['videos']
metas = []
anno_count = 0 # serve as anno_id
vid2metaid = {}
for vid_name in expression_datas:
vid_express_data = expression_datas[vid_name]
vid_frames = sorted(vid_express_data['frames'])
vid_len = len(vid_frames)
exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
for exp_id in exp_id_list:
exp_dict = vid_express_data['expressions'][exp_id]
meta = {}
meta['video'] = vid_name
meta['exp'] = exp_dict['exp'] # str
meta['mask_anno_id'] = [str(anno_count), ]
if 'obj_id' in exp_dict.keys():
meta['obj_id'] = exp_dict['obj_id']
else:
meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
meta['anno_id'] = [str(anno_count), ]
anno_count += 1
meta['frames'] = vid_frames
meta['exp_id'] = exp_id
meta['length'] = vid_len
metas.append(meta)
if vid_name not in vid2metaid.keys():
vid2metaid[vid_name] = []
vid2metaid[vid_name].append(len(metas) - 1)
# process mask annotation files
with open(mask_file, 'rb') as f:
mask_dict = pickle.load(f)
return vid2metaid, metas, mask_dict
from .collect_fns import video_lisa_collate_fn
from .MeVIS_Dataset import VideoMeVISDataset
from .ReVOS_Dataset import VideoReVOSDataset
from .RefYoutubeVOS_Dataset import VideoRefYoutubeVOSDataset
from .encode_fn import video_lisa_encode_fn
from .RefCOCO_Dataset import ReferSegmDataset
from .ReSAM2_Dataset import VideoSAM2Dataset
from .vqa_dataset import LLaVADataset, InfinityMMDataset
from .GCG_Dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset
from .Grand_Dataset import GranDDataset
from .Osprey_Dataset import OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from .ChatUniVi_Dataset import VideoChatUniViDataset
from typing import Dict, Sequence
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
def video_lisa_collate_fn(instances: Sequence[Dict],
pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
return_hf_format: bool = False,
use_varlen_attn: bool = False):
seq_parallel_world_size = get_sequence_parallel_world_size()
input_ids, labels = [], []
has_image = any(inst.get('pixel_values') is not None for inst in instances)
has_pe = any(inst.get('image_grid_thw', None) is not None for inst in instances)
has_fast_image = any(inst.get('fast_pixel_values', None) is not None for inst in instances)
has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
has_mask = any(inst.get('masks') is not None for inst in instances)
has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
has_points = any(inst.get('points') is not None for inst in instances)
has_fast_exists = any(inst.get('fast_exists') is not None for inst in instances)
has_vp = any(inst.get('vp_overall_mask') is not None for inst in instances)
has_prompt_mask = any(inst.get('prompt_masks') is not None for inst in instances)
if use_varlen_attn:
position_ids, cumulative_len = [], []
assert len(instances) == 1, (
f'If utilizing varlen attention, the batch size should be'
f' set to 1, but got {len(instances)}')
assert not has_image, (
'Currently, it is not configured to '
'accommodate the use of varlen Attention in multimodal training')
if has_image:
pixel_values = []
frames_per_batch = []
image_grid_thw = []
if has_grounding_image:
grounding_pixel_values = []
if has_mask:
object_masks = []
if has_bboxes:
object_bboxes = []
if has_points:
prompt_points = []
if has_fast_image:
fast_pixel_values = []
if has_fast_exists:
fast_exists = []
if has_vp:
vp_overall_mask = []
else:
vp_overall_mask = None
if has_prompt_mask:
prompt_masks = []
else:
prompt_masks = None
for example in instances:
input_ids.append(torch.LongTensor(example['input_ids']))
labels.append(torch.LongTensor(example['labels']))
if use_varlen_attn:
cumulative_len.append(torch.IntTensor(example['cumulative_len']))
position_ids.append(torch.LongTensor(example['position_ids']))
if has_image:
pixel_values.append(example['pixel_values'])
if has_pe:
image_grid_thw.append(example['image_grid_thw'])
if has_vp:
if 'vp_overall_mask' in example.keys() and example['vp_overall_mask'] is not None:
vp_overall_mask.append(example['vp_overall_mask'])
else:
vp_overall_mask.append(torch.Tensor([False] * len(pixel_values[-1])))
if has_fast_image:
if 'fast_pixel_values' in example.keys() and example['fast_pixel_values'] is not None:
fast_pixel_values.append(example['fast_pixel_values'])
if has_fast_exists:
if 'fast_exists' in example.keys() and example['fast_exists'] is not None:
fast_exists.append(example['fast_exists'])
if has_grounding_image and 'g_pixel_values' in example.keys():
if isinstance(example['g_pixel_values'], list):
grounding_pixel_values += example['g_pixel_values']
frames_per_batch.append(len(example['g_pixel_values']))
else:
grounding_pixel_values.append(example['g_pixel_values'])
frames_per_batch.append(1)
if has_mask:
if 'masks' in example.keys() and example['masks'] is not None:
if isinstance(example['masks'], list):
if isinstance(example['masks'][0], np.ndarray):
_masks = np.stack(example['masks'], axis=0)
_masks = torch.from_numpy(_masks)
object_masks.append(_masks)
else:
object_masks.append(torch.stack(example['masks'], dim=0))
else:
object_masks.append(example['masks'])
if has_bboxes:
if 'bboxes' in example.keys() and example['bboxes'] is not None:
object_bboxes.append(example['bboxes'])
if has_points:
if 'points' in example.keys() and example['points'] is not None:
prompt_points.append(example['points'])
if has_prompt_mask:
if 'prompt_masks' in example.keys():
prompt_masks.append(example['prompt_masks'])
ori_length = [len(ids) for ids in input_ids]
if len(instances) > 1:
input_ids = pad_sequence(
input_ids, batch_first=True, padding_value=pad_index)
labels = pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX)
else:
input_ids = torch.stack(input_ids)
labels = torch.stack(labels)
if use_varlen_attn:
assert input_ids.size(1) % seq_parallel_world_size == 0
attention_mask = None
position_ids = torch.stack(position_ids, dim=0)
else:
# Some tokenizers have the same eos token and pad token, so input_ids
# cannot be masked directly based on the pad token id.
attention_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
attention_mask[i, :length] = True
bs, seq_len = input_ids.shape
position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
if seq_parallel_world_size > 1:
input_ids = pad_for_sequence_parallel(input_ids, pad_index)
labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
position_ids = pad_for_sequence_parallel(position_ids, 0)
if attention_mask is not None:
attention_mask = pad_for_sequence_parallel(attention_mask, 0)
if use_varlen_attn:
max_seqlen = (
cumulative_len[0][1:] - # noqa: W504
cumulative_len[0][:-1]).max().item()
data_dict = {
'input_ids': input_ids,
'cumulative_len': cumulative_len,
'position_ids': position_ids,
'labels': labels,
'max_seqlen': max_seqlen
}
else:
data_dict = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids,
'labels': labels
}
if has_image:
if all(x.shape == pixel_values[0].shape for x in pixel_values):
pixel_values = torch.stack(pixel_values, dim=0)
data_dict['frames_per_batch'] = frames_per_batch
data_dict['pixel_values'] = pixel_values
if has_pe:
data_dict['image_grid_thw'] = image_grid_thw
if has_fast_image:
if all(x.shape == fast_pixel_values[0].shape for x in fast_pixel_values):
fast_pixel_values = torch.stack(fast_pixel_values, dim=0)
data_dict['fast_pixel_values'] = fast_pixel_values
if has_fast_exists:
data_dict['fast_exists'] = fast_exists
if has_vp:
data_dict['vp_overall_mask'] = torch.cat(vp_overall_mask, dim=0)
if has_prompt_mask:
data_dict['prompt_masks'] = prompt_masks
if has_grounding_image:
# if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
# grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
data_dict['g_pixel_values'] = grounding_pixel_values
if has_mask:
data_dict['masks'] = object_masks
if has_bboxes:
data_dict['bboxes'] = object_bboxes
if has_points:
data_dict['points'] = prompt_points
if return_hf_format:
return data_dict
else:
return {'data': data_dict, 'data_samples': None}
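# Minimal usage sketch (train_dataset stands for any dataset above): the collate fn pads
# input_ids/labels to the batch max length, rebuilds attention_mask and position_ids, and passes
# the visual/mask fields through largely as lists (stacking only when shapes match). In a
# non-distributed run the sequence-parallel world size is expected to be 1.
#   from torch.utils.data import DataLoader
#   loader = DataLoader(train_dataset, batch_size=2, collate_fn=video_lisa_collate_fn)
#   batch = next(iter(loader))['data']  # input_ids, attention_mask, position_ids, labels, ...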
import copy
from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
def video_lisa_encode_fn(
example,
tokenizer,
max_length,
input_ids_with_output=True,
**kwargs
):
"""We only support the following three scenarios:
1. Incremental pretraining dataset.
example['conversation'] = [
{
'input': '',
'output': '### Human: Can you write xxx'
}
]
2. Single-turn conversation dataset.
example['conversation'] = [
{
'input': 'Give three tips for staying healthy.',
'output': '1.Eat a balanced diet xxx'
}
]
3. Multi-turn conversation dataset.
example['conversation'] = [
{
'input': 'Give three tips for staying healthy.',
'output': '1.Eat a balanced diet xxx'
},
{
'input': 'Please expand on the second point.',
'output': 'Here is an expanded explanation of the xxx'
}
]
"""
bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
is_multi_turn_conversation = len(example['conversation']) > 1
if is_multi_turn_conversation:
assert input_ids_with_output
input_ids, labels = [], []
next_needs_bos_token = True
for single_turn_conversation in example['conversation']:
input = single_turn_conversation['input']
input_encode = tokenizer.encode(input, add_special_tokens=False)
if next_needs_bos_token:
input_ids += bos_token_id
labels += [IGNORE_INDEX] * len(bos_token_id)
input_ids += input_encode
labels += [IGNORE_INDEX] * len(input_encode)
if input_ids_with_output:
# Add output
output_with_loss = single_turn_conversation.get(
'output_with_loss', True)
output = single_turn_conversation['output']
output_encode = tokenizer.encode(output, add_special_tokens=False)
input_ids += output_encode
if output_with_loss:
labels += copy.deepcopy(output_encode)
else:
labels += [IGNORE_INDEX] * len(output_encode)
# Add EOS_TOKEN (with loss)
if single_turn_conversation.get('need_eos_token', True):
next_needs_bos_token = True
input_ids += eos_token_id
if output_with_loss:
labels += copy.deepcopy(eos_token_id)
else:
labels += [IGNORE_INDEX] * len(eos_token_id)
else:
next_needs_bos_token = False
# Add SEP (without loss)
sep = single_turn_conversation.get('sep', '')
if sep != '':
sep_encode = tokenizer.encode(sep, add_special_tokens=False)
input_ids += sep_encode
labels += [IGNORE_INDEX] * len(sep_encode)
if len(input_ids) > max_length:
input_ids = input_ids[:max_length]
labels = labels[:max_length]
return {'input_ids': input_ids, 'labels': labels}
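# Label convention produced above: BOS, prompt tokens, and SEP are IGNORE_INDEX; answer tokens and
# the EOS closing each turn keep their ids (unless output_with_loss is False), so only the
# assistant text contributes to the loss. Sequences longer than max_length are truncated from the right.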
def video_lisa_encode_multi_conv_fn(
example,
tokenizer,
max_length,
input_ids_with_output=True
):
"""We only support the following three scenarios:
1. Incremental pretraining dataset.
example['conversation'] = [
{
'input': '',
'output': '### Human: Can you write xxx'
}
]
2. Single-turn conversation dataset.
example['conversation'] = [
{
'input': 'Give three tips for staying healthy.',
'output': '1.Eat a balanced diet xxx'
}
]
3. Multi-turn conversation dataset.
example['conversation'] = [
{
'input': 'Give three tips for staying healthy.',
'output': '1.Eat a balanced diet xxx'
},
{
'input': 'Please expand on the second point.',
'output': 'Here is an expanded explanation of the xxx'
}
]
"""
bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
assert not input_ids_with_output
input_id_list = []
for conv in example['conversation']:
input_ids = []
next_needs_bos_token = True
for single_turn_conversation in conv:
input = single_turn_conversation['input']
input_encode = tokenizer.encode(input, add_special_tokens=False)
if next_needs_bos_token:
input_ids += bos_token_id
input_ids += input_encode
if len(input_ids) > max_length:
input_ids = input_ids[:max_length]
input_id_list.append(input_ids)
return {'input_ids': input_id_list}
import numpy as np
import random
from xtuner.utils import DEFAULT_IMAGE_TOKEN
GCG_QUESTIONS = [
DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]
def refcocog_parse_annotations(example):
# example {'id': str, 'refs': [{"sentence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}
annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [],
'file_name': example['img_file_name'], 'image': example['img_file_name']}
orig_caption = example['caption'].strip('"').strip()
annotations['caption'] = orig_caption.lower()
for detail in example['refs']:
phrase = detail['sentence']
if phrase.lower() in annotations['caption']:
annotations['labels'].append(phrase)
index = annotations['caption'].find(phrase)
end_index = index + len(phrase) if index != -1 else -1
annotations['tokens_positive'].append([index, end_index])
# still polygon or rle
annotations['masks'].append(detail["segmentation"])
# Sort tokens_positive and corresponding lists
tokens_positive = annotations['tokens_positive']
sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0])
annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices]
annotations['masks'] = [annotations['masks'][i] for i in sorted_indices]
annotations['labels'] = [annotations['labels'][i] for i in sorted_indices]
# Trimming overlapping intervals
for i in range(len(tokens_positive)):
for j in range(i + 1, len(tokens_positive)):
# If there is overlap
if tokens_positive[i][1] >= tokens_positive[j][0]:
# Modify the end index of phrase i to be one less than the start index of phrase j
tokens_positive[i][1] = tokens_positive[j][0] - 1
# Modify the phrases to reflect the change in indices
annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1]
break # Exit inner loop since i was modified
return annotations
def refcocog_conversation(caption, tokens_positive):
# insert <p> </p> and [seg] to caption and select a question
question = random.choice(GCG_QUESTIONS).strip()
# Prepare caption with tags
def tag_caption(caption, tokens):
for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
return caption
detailed_answer = tag_caption(caption, tokens_positive)
conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
return conversations
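# Worked example of the tagging above (illustrative caption):
#   caption          = 'a man riding a horse'
#   tokens_positive  = [[2, 5], [15, 20]]
#   tag_caption(...) -> 'a <p> man </p> [SEG] riding a <p> horse </p> [SEG]'
# i.e. every grounded phrase is wrapped in <p> ... </p> and followed by a [SEG] token.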
def refcocog_preprocess(example):
data_labels = example['labels']
masks = example['masks']
caption = example['caption']
tokens_positive = example['tokens_positive']
# Function to sort elements based on the start index of each phrase
def sort_by_start_index(items, order):
return [items[i] for i in order]
# Sort phrases based on their appearance in the sentence
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
masks = sort_by_start_index(masks, phrase_order)
data_labels = sort_by_start_index(data_labels, phrase_order)
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
conversations = refcocog_conversation(caption, tokens_positive)
example['conversations'] = conversations
example['labels'] = data_labels
example['masks'] = masks
example['tokens_positive'] = tokens_positive
return example
def glamm_refcocog_map_fn(example):
# example {'id': str, 'refs': [{"sentence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}
example = refcocog_parse_annotations(example)
# example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
example = refcocog_preprocess(example)
# do llava preprocess
messages = example['conversations']
input = ''
conversation = []
while messages and messages[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
messages = messages[1:]
for msg in messages:
if msg['from'] == 'human':
if DEFAULT_IMAGE_TOKEN in msg['value']:
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
'').strip()
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
msg['value'] = msg['value'].strip()
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
example.update({'conversation': conversation})
return example
def grandf_parse_annotations(example):
image_path = example['file_name']
annotations = {
'labels': [], 'caption': [], 'masks': [],
'tokens_positive': [], 'file_name': image_path,
'image': image_path}
annotations['caption'] = example['caption'].strip('"').strip()
for word, grounding in example["groundings"].items():
if grounding is None:
continue
annotations['labels'].append(word)
annotations['tokens_positive'].append(grounding["token_positives"])
annotations['masks'].append(grounding["rle_masks"])
return annotations
def grandf_conversation(caption, tokens_positive):
question = random.choice(GCG_QUESTIONS).strip()
# Prepare caption with tags
def tag_caption(caption, tokens):
for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
return caption
detailed_answer = tag_caption(caption, tokens_positive)
conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
return conversations
def grandf_preprocess(example):
data_labels = example['labels']
masks = example['masks']
caption = example['caption']
tokens_positive = example['tokens_positive']
# Function to sort elements based on the start index of each phrase
def sort_by_start_index(items, order):
return [items[i] for i in order]
# Sort phrases based on their appearance in the sentence
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
masks = sort_by_start_index(masks, phrase_order)
data_labels = sort_by_start_index(data_labels, phrase_order)
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
conversations = grandf_conversation(caption, tokens_positive)
example['conversations'] = conversations
example['labels'] = data_labels
example['masks'] = masks
example['tokens_positive'] = tokens_positive
return example
def glamm_granf_map_fn(example):
# example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
# "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
example = grandf_parse_annotations(example)
# example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
example = grandf_preprocess(example)
# do llava preprocess
messages = example['conversations']
input = ''
conversation = []
while messages and messages[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
messages = messages[1:]
for msg in messages:
if msg['from'] == 'human':
if DEFAULT_IMAGE_TOKEN in msg['value']:
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
'').strip()
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
msg['value'] = msg['value'].strip()
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
example.update({'conversation': conversation})
return example
glamm_openpsg_map_fn = glamm_granf_map_fn
def flickr_parse_annotations(example):
annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [],
'tokens_positive': [], 'image': example['file_name']}
ann_info = example["ann_info"]
for ann in ann_info:
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0))
if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
continue
bbox = [x1, y1, x1 + w, y1 + h]
annotations['bboxes'].append(bbox)
tokens_positive = ann['tokens_positive']
gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive]
annotations['labels'].append(gt_label[0])
annotations['tokens_positive'].append(tokens_positive[0])
rle = ann['sam_mask']
annotations['masks'].append(rle)
# Convert bounding boxes to numpy arrays
annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
'bboxes'] else np.zeros((0, 4), dtype=np.float32)
annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[
'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32)
return annotations
def flickr_preprocess(example):
data_labels = example['labels']
masks = example['masks']
caption = example['caption']
tokens_positive = example['tokens_positive']
# Function to sort elements based on the start index of each phrase
def sort_by_start_index(items, order):
return [items[i] for i in order]
# Sort phrases based on their appearance in the sentence
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
masks = sort_by_start_index(masks, phrase_order)
data_labels = sort_by_start_index(data_labels, phrase_order)
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
conversations = grandf_conversation(caption, tokens_positive)
example['conversations'] = conversations
example['labels'] = data_labels
example['masks'] = masks
example['tokens_positive'] = tokens_positive
return example
def glamm_flickr_map_fn(example):
# example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
# "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
example = flickr_parse_annotations(example)
example = flickr_preprocess(example)
# do llava preprocess
messages = example['conversations']
input = ''
conversation = []
while messages and messages[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
messages = messages[1:]
for msg in messages:
if msg['from'] == 'human':
if DEFAULT_IMAGE_TOKEN in msg['value']:
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
'').strip()
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
msg['value'] = msg['value'].strip()
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
example.update({'conversation': conversation})
return example
import numpy as np
import random
from xtuner.utils import DEFAULT_IMAGE_TOKEN
GCG_QUESTIONS = [
DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]
def grand_parse_annotations(example):
annotations = {
'caption': [], 'masks': [],
'tokens_positive': [], 'labels': []}
annotations['caption'] = example['dense_caption']['caption'].strip('"').strip()
object_infos = example['dense_caption']['details']
all_seg_objects_dict = {}
for seg_object_dict in example["objects"]:
all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
for seg_object_dict in example["floating_objects"]:
all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
for object_info in object_infos:
ids = object_info["ids"]
if object_info["tokens_positive"] is None:
continue
annotations['labels'].append(object_info["phrase"])
annotations['tokens_positive'].append(object_info["tokens_positive"])
_masks = []
for _id in ids:
_masks.append(all_seg_objects_dict[_id]['segmentation'])
annotations['masks'].append(_masks)
return annotations
def grand_conversation(caption, tokens_positive):
question = random.choice(GCG_QUESTIONS).strip()
# Prepare caption with tags
def tag_caption(caption, tokens):
for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
return caption
detailed_answer = tag_caption(caption, tokens_positive)
conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
return conversations
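# Hedged illustration (not part of the original file): a tiny, self-contained demo
# of grand_conversation. The caption and character spans below are made-up example
# values; spans are applied right-to-left so earlier character offsets stay valid.
def _demo_grand_conversation():
    caption = 'a dog chases a ball'
    tokens_positive = [[2, 5], [15, 19]]  # character spans of "dog" and "ball"
    convo = grand_conversation(caption, tokens_positive)
    # convo[0]['value'] is a randomly chosen GCG question (with the <image> token);
    # convo[1]['value'] == 'a <p> dog </p> [SEG] chases a <p> ball </p> [SEG]'
    return convo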
def grand_preprocess(example):
data_labels = example['labels']
masks = example['masks']
caption = example['caption']
tokens_positive = example['tokens_positive']
# Function to sort elements based on the start index of each phrase
def sort_by_start_index(items, order):
return [items[i] for i in order]
# Sort phrases based on their appearance in the sentence
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
masks = sort_by_start_index(masks, phrase_order)
data_labels = sort_by_start_index(data_labels, phrase_order)
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
conversations = grand_conversation(caption, tokens_positive)
example['conversations'] = conversations
example['labels'] = data_labels
example['masks'] = masks
example['tokens_positive'] = tokens_positive
return example
def glamm_grand_map_fn(example):
# example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
# "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
example = grand_parse_annotations(example)
# example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
example = grand_preprocess(example)
# do llava preprocess
messages = example['conversations']
input = ''
conversation = []
while messages and messages[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
messages = messages[1:]
for msg in messages:
if msg['from'] == 'human':
if DEFAULT_IMAGE_TOKEN in msg['value']:
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
'').strip()
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
msg['value'] = msg['value'].strip()
input += msg['value']
elif msg['from'] == 'gpt':
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
example.update({'conversation': conversation})
return example
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
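# Hedged illustration (not part of the original file): worked example for
# find_closest_aspect_ratio. A 1000x500 image (aspect ratio 2.0) maps to the
# (2, 1) tiling grid when at most 6 tiles of size 448 are allowed.
def _demo_find_closest_aspect_ratio():
    target_ratios = sorted(
        {(i, j) for n in range(1, 7)
         for i in range(1, n + 1) for j in range(1, n + 1)
         if 1 <= i * j <= 6},
        key=lambda x: x[0] * x[1])
    best = find_closest_aspect_ratio(2.0, target_ratios, 1000, 500, 448)
    assert best == (2, 1)
    return best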
def dynamic_preprocess(image,
min_num=1,
max_num=6,
image_size=448,
use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# enumerate candidate tiling grids whose total tile count lies in [min_num, max_num]
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
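# Hedged usage sketch (not part of the original file): tiles a synthetic PIL image
# the same way the dataset classes do. A 896x448 image with a 448 tile size yields
# two 448x448 crops plus one 448x448 thumbnail.
if __name__ == '__main__':
    from PIL import Image
    demo_image = Image.new('RGB', (896, 448))
    tiles = dynamic_preprocess(demo_image, min_num=1, max_num=6,
                               image_size=448, use_thumbnail=True)
    assert len(tiles) == 3 and all(t.size == (448, 448) for t in tiles)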
import copy
import random
import glob
import json
import logging
import os
from typing import Literal
import torch
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils
from xtuner.registry import BUILDER
from xtuner.utils import IGNORE_INDEX
from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn
from projects.glamm.datasets.utils.utils import expand2square
from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from .utils import dynamic_preprocess
class InfinityMMDataset(Dataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
tokenizer,
data_path,
prompt_template,
special_tokens=None,
max_length=8192,
offline_save_path='./work_dirs/infinityMM.json',
):
self.offline_save_path = offline_save_path
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self._system = ''
self.template = prompt_template
self.max_length = max_length
self.min_dynamic_patch = 1
self.max_dynamic_patch = 12
self.downsample_ratio = 0.5
self.image_size = 448
self.use_thumbnail = True
patch_size = 14
self.patch_token = int(
(self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB')
if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size),
interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
self.data = self._load_annotations(data_path)
self._max_refetch = 1000
def _load_annotations(self, data_path):
if os.path.exists(self.offline_save_path):
with open(self.offline_save_path, 'r') as f:
ret = json.load(f)
print(f"Load InfinityMM file list from {self.offline_save_path}, {len(ret)} items !!!")
return ret
sub_folders = []
for sub_folder in os.listdir(data_path):
if '.' not in sub_folder:
# a folder
if "LVIS_111k" in sub_folder:
# special case: this folder has an extra level of sub-folders
subsub_folders = os.listdir(os.path.join(data_path, sub_folder))
for subsub_folder in subsub_folders:
sub_folders.append(os.path.join(data_path, sub_folder, subsub_folder))
else:
sub_folders.append(os.path.join(data_path, sub_folder))
all_jsons = []
for sub_folder in sub_folders:
print(f"Processing {sub_folder} !!!")
_files = os.listdir(sub_folder)
_num = 0
for _file in _files:
if '.json' in _file:
_json_path = os.path.join(sub_folder, _file)
_num += 1
all_jsons.append(os.path.join(sub_folder, _file))
print(f"Finished {sub_folder} has {_num} items.")
with open(self.offline_save_path, 'w') as f:
json.dump(all_jsons, f)
return all_jsons
def __getitem__(self, index):
for _ in range(self._max_refetch + 1):
data = self.prepare_data(index)
# Broken images may cause the returned data to be None
if data is None:
index = self._rand_another()
continue
return data
def __len__(self):
return len(self.data)
@property
def modality_length(self):
self.group_length = []
for data_dict in self.data:
self.group_length.append(100)
return self.group_length
@property
def length(self):
group_length = np.array(self.group_length)
group_length = np.abs(group_length).tolist()
return group_length
def prepare_data(self, index):
data_path = self.data[index]
with open(data_path, 'r') as f:
data_dict = json.load(f)
if 'image' in data_dict.keys():
data_dict['image'] = data_path.replace('.json', '.jpg')
if data_dict is None:
return None
out_data_dict = {}
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
try:
image = Image.open(image_file).convert('RGB')
except Exception as e:
print(f'Error: {e}', flush=True)
print_log(f'Error: {e}', logger='current')
return None
images = dynamic_preprocess(image, self.min_dynamic_patch,
self.max_dynamic_patch,
self.image_size, self.use_thumbnail)
pixel_values = [self.transformer(image) for image in images]
pixel_values = torch.stack(pixel_values)
out_data_dict['pixel_values'] = pixel_values
num_image_tokens = pixel_values.shape[0] * self.patch_token
image_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
token_dict = self.get_inputid_labels(
data_dict['conversations'], image_token_str)
out_data_dict.update(token_dict)
else:
token_dict = self.get_inputid_labels(
data_dict['conversations'], None)
out_data_dict.update(token_dict)
out_data_dict['pixel_values'] = torch.zeros(
1, 3, self.image_size, self.image_size)
return out_data_dict
def _rand_another(self) -> int:
return np.random.randint(0, len(self.data))
def get_inputid_labels(self, conversations, image_token_str) -> dict:
input = ''
out_conversation = []
while conversations and conversations[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
conversations = conversations[1:]
for i, msg in enumerate(conversations):
if msg['from'] == 'human':
# keep only a single <image> token, placed at the start of the first human turn
if '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>\n', '').replace('<image>', '')
if i == 0:
msg['value'] = "<image>\n" + msg['value']
if image_token_str is None and '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>', '')
if '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
input += msg['value'].strip()
elif msg['from'] == 'gpt':
out_conversation.append({
'input': input,
'output': msg['value'].strip()
})
input = ''
else:
raise NotImplementedError
input_ids, labels = [], []
for i, single_turn_conversation in enumerate(out_conversation):
input = single_turn_conversation.get('input', '')
if input is None:
input = ''
input_text = self.template.INSTRUCTION.format(
input=input, round=i + 1)
if i == 0:
if self._system != '' and self._system is not None:
system = self.template.SYSTEM.format(system=self._system)
input_text = system + input_text
input_encode = self.tokenizer.encode(
input_text, add_special_tokens=True)
else:
input_encode = self.tokenizer.encode(
input_text, add_special_tokens=False)
input_ids += input_encode
labels += [IGNORE_INDEX] * len(input_encode)
output_text = single_turn_conversation.get('output', '')
if self.template.get('SUFFIX', None):
output_text += self.template.SUFFIX
output_encode = self.tokenizer.encode(
output_text, add_special_tokens=False)
input_ids += output_encode
labels += copy.deepcopy(output_encode)
if len(input_ids) > self.max_length:
input_ids = input_ids[:self.max_length]
labels = labels[:self.max_length]
print_log(
f'Warning: input_ids length({len(input_ids)}) '
f'is longer than max_length, cut to {self.max_length}',
logger='current')
return {'input_ids': input_ids, 'labels': labels}
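# Hedged usage sketch (not part of the original commit): how InfinityMMDataset might
# be constructed. The data_path default is a placeholder; the prompt_template is
# assumed to be an xtuner PROMPT_TEMPLATE entry. Samples come back as dicts with
# 'input_ids', 'labels' and 'pixel_values'.
def _build_infinitymm_dataset_demo(tokenizer_cfg, prompt_template,
                                   data_path='data/InfinityMM/'):
    return InfinityMMDataset(
        tokenizer=tokenizer_cfg,
        data_path=data_path,
        prompt_template=prompt_template,
        special_tokens=['[SEG]'],
        max_length=8192)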
class LLaVADataset(Dataset):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def __init__(self,
tokenizer,
data_path,
prompt_template,
special_tokens=None,
image_folder=None,
max_length=8192,
arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
preprocessor=None,
skip_pure_text=False,
):
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.image_folder = image_folder
self.template = prompt_template
self.max_length = max_length
self._system = ''
self.arch_type = arch_type
self.min_dynamic_patch = 1
self.max_dynamic_patch = 12
self.downsample_ratio = 0.5
if self.arch_type == 'llava':
self.downsample_ratio = 1
self.image_size = 448
if self.arch_type == 'llava':
self.image_size = 336
self.use_thumbnail = True
patch_size = 14
self.patch_token = int(
(self.image_size // patch_size)**2 * (self.downsample_ratio**2))
if self.arch_type == 'qwen':
self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
self.IMG_START_TOKEN = '<|vision_start|>'
self.IMG_END_TOKEN = '<|vision_end|>'
elif self.arch_type == 'llava':
self.IMG_CONTEXT_TOKEN = '<image>'
self.IMG_START_TOKEN = ''
self.IMG_END_TOKEN = ''
if preprocessor is None:
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
self.preprocessor = None
else:
self.transformer = None
self.preprocessor = BUILDER.build(preprocessor)
self.data = self._load_annotations(data_path, image_folder)
self._max_refetch = 1000
self.skip_pure_text = skip_pure_text
def _load_annotations(self, data_path, image_folder=None):
data = json.load(open(data_path))
return data
def __getitem__(self, index):
for _ in range(self._max_refetch + 1):
data = self.prepare_data(index)
# Broken images may cause the returned data to be None
if data is None:
index = self._rand_another()
continue
return data
def __len__(self):
return len(self.data)
@property
def modality_length(self):
self.group_length = []
for data_dict in self.data:
self.group_length.append(100)
return self.group_length
@property
def length(self):
group_length = np.array(self.group_length)
group_length = np.abs(group_length).tolist()
return group_length
def prepare_data(self, index):
data_dict: dict = self.data[index]
if data_dict is None:
return None
out_data_dict = {}
if self.skip_pure_text and data_dict.get('image', None) is None:
return None
if data_dict.get('image', None) is not None:
image_file = os.path.join(self.image_folder, data_dict['image'])
try:
image = Image.open(image_file).convert('RGB')
except Exception as e:
print(f'Error: {e}', flush=True)
print_log(f'Error: {e}', logger='current')
return None
if self.preprocessor is not None:
# images = dynamic_preprocess(image, self.min_dynamic_patch,
# self.max_dynamic_patch,
# self.image_size, self.use_thumbnail)
images = [image]
if self.arch_type == 'qwen':
_data_dict = self.preprocessor(images, do_resize=True)
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
_data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
elif self.arch_type == 'llava':
_data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
_data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
else:
raise NotImplementedError
out_data_dict.update(_data_dict)
else:
images = dynamic_preprocess(image, self.min_dynamic_patch,
self.max_dynamic_patch,
self.image_size, self.use_thumbnail)
pixel_values = [self.transformer(image) for image in images]
pixel_values = torch.stack(pixel_values)
out_data_dict['pixel_values'] = pixel_values
num_image_tokens = pixel_values.shape[0] * self.patch_token
image_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
token_dict = self.get_inputid_labels(
data_dict['conversations'], image_token_str)
out_data_dict.update(token_dict)
else:
token_dict = self.get_inputid_labels(
data_dict['conversations'], None)
out_data_dict.update(token_dict)
out_data_dict['pixel_values'] = torch.zeros(
1, 3, self.image_size, self.image_size)
return out_data_dict
def _rand_another(self) -> int:
return np.random.randint(0, len(self.data))
def get_inputid_labels(self, conversations, image_token_str) -> dict:
input = ''
out_conversation = []
while conversations and conversations[0]['from'] == 'gpt':
# Skip the first one if it is from gpt
conversations = conversations[1:]
for msg in conversations:
if msg['from'] == 'human':
if image_token_str is None and '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>', '')
if '<image>' in msg['value']:
msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
input += msg['value'].strip()
elif msg['from'] == 'gpt':
out_conversation.append({
'input': input,
'output': msg['value'].strip()
})
input = ''
else:
raise NotImplementedError
input_ids, labels = [], []
for i, single_turn_conversation in enumerate(out_conversation):
input = single_turn_conversation.get('input', '')
if input is None:
input = ''
input_text = self.template.INSTRUCTION.format(
input=input, round=i + 1)
if i == 0:
if self._system != '' and self._system is not None:
system = self.template.SYSTEM.format(system=self._system)
input_text = system + input_text
input_encode = self.tokenizer.encode(
input_text, add_special_tokens=True)
else:
input_encode = self.tokenizer.encode(
input_text, add_special_tokens=False)
input_ids += input_encode
labels += [IGNORE_INDEX] * len(input_encode)
output_text = single_turn_conversation.get('output', '')
if self.template.get('SUFFIX', None):
output_text += self.template.SUFFIX
output_encode = self.tokenizer.encode(
output_text, add_special_tokens=False)
input_ids += output_encode
labels += copy.deepcopy(output_encode)
if len(input_ids) > self.max_length:
input_ids = input_ids[:self.max_length]
labels = labels[:self.max_length]
print_log(
f'Warning: input_ids length({len(input_ids)}) '
f'is longer than max_length, cut to {self.max_length}',
logger='current')
return {'input_ids': input_ids, 'labels': labels}
if __name__ == '__main__':
from transformers import CLIPImageProcessor, AutoTokenizer
from third_parts.segment_anything.utils.transforms import ResizeLongestSide
pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=llm_name_or_path)
image_processor = dict(
type=CLIPImageProcessor.from_pretrained,
pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
extra_image_processor = dict(
type=ResizeLongestSide,
target_length=1024,
)
from xtuner.utils.templates import PROMPT_TEMPLATE
prompt_template = PROMPT_TEMPLATE.vicuna
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
dataset = LLaVADataset(
tokenizer=tokenizer,
data_path='data/llava_data/LLaVA-Instruct-150K/llava_instruct_150k.json',
prompt_template=prompt_template,
special_tokens=['[SEG]'],
image_folder='data/coco/train2017/',
)
for i in range(1000):
dataset[i]