Commit 0063a668 authored by chenzk

v1.0

import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from PIL import Image
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
import torchvision.io as tv_io
import time
import random
from decord import VideoReader, cpu
class Constructor():
def __init__(self, **kwargs):
self.trace = visual_trace(linewidth=4)
self.mm_use_trace_start_end = kwargs.get('mm_use_trace_start_end', False)
self.mm_use_trace_speed = kwargs.get('mm_use_trace_speed', False)
self.mm_use_image_start_end = kwargs.get('mm_use_image_start_end', False)
self.mm_use_image_history = kwargs.get('mm_use_image_history', False)
self.remove_static_trace_pts = kwargs.get('remove_static_trace_pts', False)
self.show_trace = kwargs.get('show_trace', False)
self.video_reader = kwargs.get('video_reader', 'decord')
if self.mm_use_image_start_end:
self.image_placeholder = '<image_start><image><image_end>\n'
else:
self.image_placeholder = '<image>\n'
def _get_frame(self, video_path, frame_start, frame_pos, size):
if video_path.endswith('.jpg') or video_path.endswith('.png'):
image = Image.open(video_path).resize(size)
return image
if self.video_reader == 'cv2':
video_cap = cv2.VideoCapture(video_path)
num_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
if frame_start + frame_pos >= num_frames or frame_start + frame_pos < 0:
frame_pos = 0
trials = 0
video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start + frame_pos)
while trials < 5:
success, image = video_cap.read()
if success:
break
else:
time.sleep(0.1)
trials += 1
if not success:
print(f"Failed to read video {video_path} at frame {frame_start + frame_pos}")
image = None
else:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image).resize(size)
video_cap.release()
return image
elif self.video_reader == 'decord':
try:
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
num_frames = len(vr)
if frame_start+frame_pos >= num_frames:
frame_pos = 0
frame_idx = [frame_start+frame_pos]
image = vr.get_batch(frame_idx).asnumpy()[0]
# https://github.com/dmlc/decord/issues/208
vr.seek(0)
# convert image to rgb format
image = Image.fromarray(image).resize(size)
return image
except Exception as e:
print(f"Failed to read video {video_path} at frame {frame_start + frame_pos}")
return None
def _process_gpt_response(self, gpt_response, task_description):
"""
Process the gpt_response
"""
gpt_response = gpt_response.replace('What you see', 'What I see')
gpt_response = gpt_response.replace('you see ', '').replace('You see ', '')
gpt_response = gpt_response.replace('you', 'the person')
gpt_response = gpt_response.replace('your', '')
gpt_response = gpt_response.replace('In the first image, ', '')
# gpt_response = gpt_response.replace('What the person should do next', 'What you should do next')
# gpt_response = gpt_response.replace('What you should do next', 'What you are doing')
gpt_response = gpt_response.replace('personr', 'person\'s')
# remove all parenthesized strings (marks) from the gpt_response
gpt_response = re.sub(r' \([^)]*\)', '', gpt_response)
gpt_response = gpt_response if len(gpt_response) > 0 else task_description
gpt_response = gpt_response.replace('camera wearer', 'person')
return gpt_response
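# A hedged, made-up example (not from the original data) of how the rewriting above behaves:
#   in:  "You see a knife on the table. What your left hand should do next is cut the onion (step 1)."
#   out: "a knife on the table. What the person's left hand should do next is cut the onion."
# 'your' first becomes 'the personr' via the 'you' replacement and is then repaired by the
# 'personr' -> "person's" replacement; the parenthesized mark is stripped by the regex.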
def _construct_conv_semantic(self, item, gpt_response, num_image_tokens=1):
"""
Construct conversations for semantic (language) prediction
"""
image_placeholder = ''.join([self.image_placeholder]*num_image_tokens)
# model task 1: ask the model to briefly describe the current image - understand the present
if item['dataset_tag'] in ['ego4d', 'sthv2']:
conv_user = (
f'{image_placeholder}\nWhat is the person doing in the image?\n'
)
conv_gpt = gpt_response + '\n'
gpt_response_todo = gpt_response
elif item['dataset_tag'] == 'human_instruction':
# for human instruction, it is narration
conv_user = (
f'{image_placeholder}\nThe person is doing some task in the image. Guess what is the person saying?\n'
)
conv_gpt = gpt_response + '\n'
gpt_response_todo = gpt_response
elif item['dataset_tag'] in ['epic']:
gpt_response_see = gpt_response.split('What the person should do next')[0].replace('#','').replace('*','').replace('What I see:', '').strip()
conv_user = (
f'{image_placeholder}\nWhat do you see in the image?\n'
)
conv_gpt = gpt_response_see + '\n'
gpt_response_todo = gpt_response.split('What the person should do next')[1].replace('#','').replace('*', '').replace(':','').strip()
elif item['dataset_tag'] in ['openx_magma']:
conv_user = (
f'{image_placeholder}\nWhat is the robot doing in the image?\n'
)
conv_gpt = gpt_response + '\n'
gpt_response_todo = gpt_response
return conv_user, conv_gpt, gpt_response_todo
def _construct_conv_som(self, item, image, visual_traces, frame_pos, pos_traces_to_mark=None, neg_traces_to_mark=None, normalize=True):
"""
Construct conversations for spatial prediction
"""
if pos_traces_to_mark is None or neg_traces_to_mark is None:
pred_tracks = visual_traces['pred_tracks']
pred_visibility = visual_traces['pred_visibility']
# randomly sample pos_tracks and neg_tracks
num_clusters_pos = torch.randint(2, 6, (1,)).item()
num_clusters_neg = torch.randint(6, 15, (1,)).item()
pos_tracks = pred_tracks[:,frame_pos:,torch.randint(0, pred_tracks.size(2), (num_clusters_pos,))]
neg_tracks = pred_tracks[:,frame_pos:,torch.randint(0, pred_tracks.size(2), (num_clusters_neg,))]
image, pos_traces_to_mark, neg_traces_to_mark, pos_mark_ids, neg_mark_ids, all_idx = \
som_prompting(image, pos_tracks, neg_tracks, draw_som_positive=True, draw_som_negative=True)
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks.\n"
f"Please locate all the numerical marks in the image.\n"
)
# combine pos_traces_to_mark and neg_traces_to_mark
pos_traces_to_mark.update(neg_traces_to_mark)
# sort pos_traces_to_mark by the key
pos_traces_to_mark = dict(sorted(pos_traces_to_mark.items()))
marks_pos = []
for key, val in pos_traces_to_mark.items():
trace = val[0]
if normalize:
x = int(self.spatial_quant_size * trace[0, 0] / image.size[0])
y = int(self.spatial_quant_size * trace[0, 1] / image.size[1])
else:
x = int(trace[0, 0])
y = int(trace[0, 1])
val_str = f"[{x},{y}]"
marks_pos.append(f'Mark {key} at {val_str}')
conv_gpt = ". ".join(marks_pos) + '\n'
return conv_user, conv_gpt, image
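# Illustration only (made-up values): with spatial_quant_size=256, an image of size (640, 360), and
# two marks at pixels (120, 48) and (300, 200), the normalized answer built above would read:
#   "Mark 1 at [48,34]. Mark 2 at [120,142]\n"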
def _construct_conv_tom(self, item, video_path, visual_traces):
"""
Construct conversations for spatial-temporal prediction
"""
def _construct_conv(self, item, video_path, visual_traces):
# NOTE: for pretraining on video, we always set num_crops to 1 to save memory cost
item['num_crops'] = 1
if video_path is None and visual_traces is None:
dummy_conversations = []
dummy_conversations.append({'from': 'human', 'value': f"{self.image_placeholder}\nWhat is in this image?"})
dummy_conversations.append({'from': 'gpt', 'value': "This is a blank image."})
item['conversations'] = dummy_conversations
item['image'] = None
return item
if 'image_size' not in item:
assert '(height,width)' in item, f"image_size not in item and (height,width) not in item"
item['image_size'] = item['(height,width)'][::-1]
if isinstance(item['image_size'][0], torch.Tensor):
width, height = item['image_size'][0].item(), item['image_size'][1].item()
frame_start, frame_end = item['frame_interval'][0].item(), item['frame_interval'][1].item()
task_description = item['global_instructions'][0]
gpt_response = item['gpt_response'][0]
else:
width, height = item['image_size']
frame_start, frame_end = item['frame_interval']
task_description = item['global_instructions']
gpt_response = item['gpt_response']
gpt_response = self._process_gpt_response(gpt_response, task_description)
if self.mm_use_image_history:
# randomly sample at most 3 unique indices in range [0, frame_start)
frame_idx = torch.randperm(frame_start)[:3].sort().values.tolist() + [frame_start]
else:
frame_idx = [frame_start]
item['image'] = self._get_frames_with_idx(video_path, frame_idx, (width, height))
if item['image'] is None:
dummy_conversations = []
dummy_conversations.append({'from': 'human', 'value': f"{self.image_placeholder}\nWhat is in this image?"})
dummy_conversations.append({'from': 'gpt', 'value': "This is a blank image."})
item['conversations'] = dummy_conversations
return item
conv_user, conv_gpt, gpt_response_todo = self._construct_conv_semantic(item, gpt_response, len(item['image']))
item['conversations'] = [
{'from': 'human', 'value': conv_user},
{'from': 'gpt', 'value': conv_gpt}
]
if not self.use_som_tom or random.random() < 0.2:
return item
if visual_traces is None:
return item
if len(visual_traces['pred_tracks'].shape) == 3:
visual_traces['pred_tracks'] = visual_traces['pred_tracks'][None]
if len(visual_traces['pred_visibility'].shape) == 2:
visual_traces['pred_visibility'] = visual_traces['pred_visibility'][None]
# model task 2: ask the model to predict the movements of the person and/or the object - visual action for the future
# sort pos_traces_to_mark by the key
# calculate the trace length for each step
track_length = torch.norm(visual_traces['pred_tracks'][:, 1:] - visual_traces['pred_tracks'][:, :-1], dim=3).mean(2)
# accum_sum track_length
accum_sum = torch.cumsum(track_length, dim=1) / (1e-5 + track_length.sum(1)[:, None])
# find last position
frame_rightmost = min(max(1, (accum_sum[0] < self.settings['trace_planner']['step_rightmost_ratio']).int().sum().item()), visual_traces['pred_tracks'].shape[1]-1)
# randomly select a frame position, but not the last frame
frame_pos = torch.randint(0, frame_rightmost, (1,)).item()
pred_tracks = visual_traces['pred_tracks'][:, frame_pos:]
pred_visibility = visual_traces['pred_visibility'][:, frame_pos:]
step_to_predict = pred_tracks.size(1)
if step_to_predict == 0:
return item
pred_tracks_history = visual_traces['pred_tracks'][:, :max(1, frame_pos+1)]
pred_visibility_history = visual_traces['pred_visibility'][:, :max(1, frame_pos+1)]
# only keep points that are visible in at least half of the steps
valid_idx = pred_visibility[0].sum(0) > 0.5*pred_tracks.shape[1]
if valid_idx.sum() <= 1:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
return item
pred_tracks = pred_tracks[:, :, valid_idx]
pred_visibility = pred_visibility[:, :, valid_idx]
pred_tracks_history = pred_tracks_history[:, :, valid_idx]
pred_visibility_history = pred_visibility_history[:, :, valid_idx]
if self.show_trace:
image = self._get_frame(video_path, frame_start, frame_pos, (width, height))
for k in range(0,pred_tracks.shape[1],5):
image_k = self._get_frame(video_path, frame_start, frame_pos+k, (width, height))
if image_k is not None:
# mkdir
if not os.path.exists(f"./release/videos/trace_{item['video'][0].replace('/', '_').replace('.MP4', '')}"):
os.makedirs(f"./release/videos/trace_{item['video'][0].replace('/', '_').replace('.MP4', '')}")
image_k.save(f"./release/videos/trace_{item['video'][0].replace('/', '_').replace('.MP4', '')}/frame_{k}.jpg")
images = [image] * pred_tracks.shape[1]
video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float()*255
self.trace.visualizer.save_dir = "./release/videos"
_ = self.trace.visualize(video, pred_tracks, pred_visibility, filename=f"trace_{item['video'][0].replace('/', '_').replace('.MP4', '')}", mode="rainbow")
# calculate the trajectory length for pred_tracks
pred_tracks_length = self.trace.visual_trace_length(pred_tracks, pred_visibility, (1, 1)).squeeze(0)
# if more than 80% of the track lengths are larger than 1, assume there is camera motion
camera_motion = (pred_tracks_length > 1).sum() > 0.8*pred_tracks_length.size(0)
camera_motion = True if item['dataset_tag'] in ['ego4d', 'epic', 'exoego4d'] else camera_motion
start_pos = pred_tracks[:, 0][0]
reference_pts_np = start_pos.cpu().numpy().reshape(-1, 2)
if camera_motion:
# remove camera motion using homography transformation
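# The RANSAC homography below maps each later frame's points back into the coordinate frame of the
# start frame, so the residual displacement of a trace reflects object/hand motion rather than
# global camera motion.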
try:
future_pts_transformed = []
for k in range(1, pred_tracks.shape[1]):
future_pts = pred_tracks[:, k][0]
future_pts_np = future_pts.cpu().numpy().reshape(-1, 2)
try:
(H, status) = cv2.findHomography(future_pts_np, reference_pts_np, cv2.RANSAC, 4.0)
except Exception as e:
continue
future_pts_np_transformed = cv2.perspectiveTransform(future_pts_np.reshape(1, -1, 2), H).reshape(-1, 2)
future_pts_transformed_k = torch.tensor(future_pts_np_transformed, dtype=torch.float32)
future_pts_transformed.append(future_pts_transformed_k)
pred_tracks = torch.stack([start_pos] + future_pts_transformed, dim=0).unsqueeze(0)
except Exception as e:
pass
if pred_tracks_history.size(1) > 0:
try:
history_pts_transformed = []
for k in range(0, pred_tracks_history.shape[1]):
history_pts = pred_tracks_history[:, k][0]
history_pts_np = history_pts.cpu().numpy().reshape(-1, 2)
try:
(H, status) = cv2.findHomography(history_pts_np, reference_pts_np, cv2.RANSAC, 4.0)
except Exception as e:
continue
history_pts_np_transformed = cv2.perspectiveTransform(history_pts_np.reshape(1, -1, 2), H).reshape(-1, 2)
history_pts_transformed_k = torch.tensor(history_pts_np_transformed, dtype=torch.float32)
history_pts_transformed.append(history_pts_transformed_k)
pred_tracks_history = torch.stack(history_pts_transformed, dim=0).unsqueeze(0)
except Exception as e:
pass
# step 2: find positive traces and negative traces
track_length = self.trace.visual_trace_length(pred_tracks, pred_visibility, (1, 1)).squeeze(0)
threshold = 3 # max(track_length.max(), 2) * self.settings['trace_processor']['postive_factor_threshold']
# video is almost static
if (track_length > threshold).sum() <= 1:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
return item
else:
# find the positive traces and negative traces
pos_tracks = pred_tracks[:, :, track_length > threshold]
pos_visibility = pred_visibility[:, :, track_length > threshold]
pos_tracks_history = pred_tracks_history[:, :, track_length > threshold]
pos_visibility_history = pred_visibility_history[:, :, track_length > threshold]
neg_tracks = pred_tracks[:, :, track_length <= threshold]
neg_tracks_history = pred_tracks_history[:, :, track_length <= threshold]
# clustering for positive traces
# randomly sample a number between 2 and self.num_clusters
num_clusters_pos = torch.randint(2, 6, (1,)).item()
pos_sampled_ids = self.trace.cluster_traces_kmeans(pos_tracks, n_clusters=num_clusters_pos, positive=True)
if pos_sampled_ids is None:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
return item
pos_tracks = pos_tracks[:, :, pos_sampled_ids.bool()]
pos_visibility = pos_visibility[:, :, pos_sampled_ids.bool()]
pos_tracks_history = pos_tracks_history[:, :, pos_sampled_ids.bool()]
pos_visibility_history = pos_visibility_history[:, :, pos_sampled_ids.bool()]
# clustering for negative traces
num_clusters_neg = torch.randint(6, 15, (1,)).item()
neg_sampled_ids = self.trace.cluster_traces_kmeans(neg_tracks, n_clusters=num_clusters_neg)
if neg_sampled_ids is None:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
return item
neg_tracks = neg_tracks[:, :, neg_sampled_ids.bool()]
neg_tracks_history = neg_tracks_history[:, :, neg_sampled_ids.bool()]
image = self._get_frame(video_path, frame_start, frame_pos, (width, height))
if image is None:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
return item
# image = tom_prompting(self.trace, image, pos_tracks_history, neg_tracks_history, draw_som_positive=False, draw_som_negative=False)
image, pos_traces_to_mark, neg_traces_to_mark, pos_mark_ids, neg_mark_ids, all_idx = \
som_prompting(image, pos_tracks, neg_tracks, draw_som_positive=True, draw_som_negative=True)
# visualize the traces
if self.show_trace:
images = [image] * pos_tracks.shape[1]
video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float()*255
self.trace.visualizer.save_dir = "./release/videos"
_ = self.trace.visualize(video, pos_tracks, pos_visibility, filename=f"tom_{item['video'][0].replace('/', '_').replace('.MP4', '')}", mode="rainbow")
mark_ids = sorted([key for key in pos_traces_to_mark.keys()] + [key for key in neg_traces_to_mark.keys()])
pos_traces_to_mark = dict(sorted(pos_traces_to_mark.items()))
mark_trace_history = ''
mark_trace_future = ''
valid_marks = {}
speeds = {}
for key, val in pos_traces_to_mark.items():
# random select a frame position but not the last frame
# frame_pos = torch.randint(0, trace.size(0)-1, (1,)).item()
trace = val[0]
trace[:, 0] = self.spatial_quant_size * trace[:, 0] / width
trace[:, 1] = self.spatial_quant_size * trace[:, 1] / height
trace_temp = trace.clone()
# remove (almost) static points
trace_temp = self.trace.remove_close_points_tensor(trace_temp, 2)
# remove invisible points
trace_temp = trace_temp[(trace_temp > 0).sum(1) == 2]
if trace_temp.size(0) <= step_to_predict // 4:
continue
# calculate motion speed
# if trace_temp.size(0) < step_to_predict:
# trace_temp = torch.cat([trace_temp, trace_temp[-1].repeat(step_to_predict - trace_temp.size(0), 1)], dim=0)
# elif trace_temp.size(0) > step_to_predict:
# trace_temp = trace_temp[:step_to_predict]
# calculate speed
speed = torch.norm(trace_temp[1:] - trace_temp[:-1], dim=1).mean()
if torch.isnan(speed):
continue
speeds[key] = speed.item()
if speed < self.settings['trace_processor']['postive_speed_threshold']:
continue
# trace_history = trace[0]
# val_str_history = '[' + ','.join([f'[{x[0]},{x[1]}]' for x in trace_history.tolist()]) + ']'
# mark_trace_history += f'\"Mark {key}\": \"{val_str_history}\"\n'
# round trace_temp
if self.remove_static_trace_pts:
valid_marks[key] = trace_temp.int()
else:
valid_marks[key] = trace.int()
# NOTE: there was a bug here
val_str_future = '[' + ','.join([f'[{x[0]},{x[1]}]' for x in valid_marks[key][1:].tolist()]) + ']'
mark_trace_future += f'\"Mark {key}\": \"{val_str_future}\"\n\n'
if len(mark_trace_future) > 0:
# find the maximal number of future steps from valid_marks
num_future_steps = [val.shape[0]-1 for val in valid_marks.values()]
step_to_predict = max(num_future_steps)
if item['dataset_tag'] != 'human_instruction':
if self.mm_use_trace_speed:
# calculate the average speed of the marks
avg_speed = int(sum(speeds.values()) / len(speeds))
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks {mark_ids}.\n"
f"The person is doing: {gpt_response_todo}. To finish the task, how to move the numerical marks in the image with speed {avg_speed} for the next {step_to_predict} steps?\n"
)
else:
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks {mark_ids}.\n"
f"The person is doing: {gpt_response_todo}. To finish the task, how to move the numerical marks in the image for the next {step_to_predict} steps?\n"
)
else:
if self.mm_use_trace_speed:
# calculate the average speed of the marks
avg_speed = int(sum(speeds.values()) / len(speeds))
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks {mark_ids}.\n"
f"The person is saying: {gpt_response_todo}. To finish the task, how to move the numerical marks in the image with speed {avg_speed} for the next {step_to_predict} steps?\n"
)
else:
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks {mark_ids}.\n"
f"The person is saying: {gpt_response_todo}. To finish the task, how to move the numerical marks in the image for the next {step_to_predict} steps?\n"
)
# if self.mm_use_trace_speed:
# # calculate speed
# formmated_val = '. '.join([f"Mark {key} at [{val[0][0].item()},{val[0][1].item()}] will move {val.shape[0]-1} steps with speed {round(speeds[key])}" for key, val in valid_marks.items()])
# else:
formmated_val = ', '.join([f"Mark {key} at [{val[0][0].item()},{val[0][1].item()}]" for key, val in valid_marks.items()])
if self.mm_use_trace_start_end:
mark_trace_future = f'<trace_start>{mark_trace_future}<trace_end>'
conv_gpt = f"{formmated_val} should be moved, and their future positions are:\n\n{mark_trace_future}"
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
else:
for key, val in neg_traces_to_mark.items():
trace = val[0]
trace[:, 0] = self.spatial_quant_size * trace[:, 0] / width
trace[:, 1] = self.spatial_quant_size * trace[:, 1] / height
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, frame_pos, pos_traces_to_mark, neg_traces_to_mark, normalize=False)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image'].append(image)
return item
def _get_frames(self, video_path, frame_start, frame_end, size):
try:
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
num_frames = len(vr)
if frame_end >= num_frames:
frame_end = num_frames - 1
frame_idx = list(range(frame_start, frame_end))
images = vr.get_batch(frame_idx).asnumpy()
# https://github.com/dmlc/decord/issues/208
vr.seek(0)
# convert image to rgb format
# reduce image size to speed up the process
size = (size[0]//2, size[1]//2)
images = [Image.fromarray(image).resize(size) for image in images]
return images
except Exception as e:
print(f"Failed to read frames from video {video_path}")
return None
def _get_frames_with_idx(self, video_path, frame_idx, size):
if video_path.endswith('.jpg') or video_path.endswith('.png'):
images = []
# read all images in frame_idx
for idx in frame_idx[:-1]:
video_path_temp = video_path.replace(f'{frame_idx[-1]}.jpg', f'{idx}.jpg').replace(f'{frame_idx[-1]}.png', f'{idx}.png')
if not os.path.exists(video_path_temp):
continue
image = Image.open(video_path_temp).resize(size)
images.append(image)
image = Image.open(video_path).resize(size)
images.append(image)
return images
try:
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
num_frames = len(vr)
# remove frames that are out of range in frame_idx
frame_idx = [idx for idx in frame_idx if idx < num_frames]
images = vr.get_batch(frame_idx).asnumpy()
# https://github.com/dmlc/decord/issues/208
vr.seek(0)
# convert image to rgb format
# reduce image size to speed up the process
# size = (size[0]//2, size[1]//2)
images = [Image.fromarray(image).resize(size) for image in images]
return images
except Exception as e:
print(f"Failed to read frames from video {video_path}")
return None
def _construct_caption(self, item, video_path, visual_traces):
"""
v4->v5: add trace of mark
"""
if video_path is None and visual_traces is None:
dummy_conversations = []
dummy_conversations.append({'from': 'human', 'value': f"{self.image_placeholder}\nWhat is in this image?"})
dummy_conversations.append({'from': 'gpt', 'value': "This is a blank image."})
item['conversations'] = dummy_conversations
item['image'] = None
return item
if 'image_size' not in item:
assert '(height,width)' in item, f"image_size not in item and (height,width) not in item"
item['image_size'] = item['(height,width)'][::-1]
if isinstance(item['image_size'][0], torch.Tensor):
width, height = item['image_size'][0].item(), item['image_size'][1].item()
frame_start, frame_end = item['frame_interval'][0].item(), item['frame_interval'][1].item()
task_description = item['global_instructions'][0]
gpt_response = item['gpt_response'][0]
else:
width, height = item['image_size']
frame_start, frame_end = item['frame_interval']
task_description = item['global_instructions']
gpt_response = item['gpt_response']
gpt_response = self._process_gpt_response(gpt_response, task_description)
item['image'] = self._get_frames(video_path, frame_start, frame_end, (width, height))
if item['image'] is not None:
image_placeholder = ''.join([self.image_placeholder] * len(item['image']))
conv_user = (
f'{image_placeholder}\nWhat do you see in the first image? And what will the person do next?\n'
)
conv_gpt = gpt_response + '\n'
item['conversations'] = [
{'from': 'human', 'value': conv_user},
{'from': 'gpt', 'value': conv_gpt}
]
else:
image_placeholder = ''.join([self.image_placeholder])
conv_user = (
f'{image_placeholder}\nWhat is in this image?\n'
)
conv_gpt = "This is a blank image.\n"
item['conversations'] = [
{'from': 'human', 'value': conv_user},
{'from': 'gpt', 'value': conv_gpt}
]
return item
import torch
from dataclasses import dataclass, field
from magma.processing_magma import MagmaProcessor
from typing import Dict, Optional, Sequence, List
import transformers
from data.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
processor: MagmaProcessor
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels, pixel_values, image_sizes = \
tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "pixel_values", "image_sizes"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.processor.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.processor.tokenizer.model_max_length]
labels = labels[:, :self.processor.tokenizer.model_max_length]
pixel_values = [torch.cat(pv, dim=0) for pv in pixel_values]
image_sizes = [torch.cat(isz, dim=0) for isz in image_sizes]
pixel_values_padded = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0)
image_sizes_padded = torch.nn.utils.rnn.pad_sequence(image_sizes, batch_first=True, padding_value=0)
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.processor.tokenizer.pad_token_id),
pixel_values=pixel_values_padded,
image_sizes=image_sizes_padded
)
return batch
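# A minimal usage sketch (hypothetical names; assumes an already-built MagmaProcessor `processor`
# and a dataset whose items provide input_ids, labels, pixel_values and image_sizes):
#   collator = DataCollatorForSupervisedDataset(processor=processor)
#   loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, collate_fn=collator)
#   batch = next(iter(loader))  # dict with input_ids, labels, attention_mask, pixel_values, image_sizes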
@dataclass
class DataCollatorForHFDataset(object):
"""Collate hugging face examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.tokenizer.model_max_length]
labels = labels[:, :self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
if 'image' in instances[0] and instances[0]['image'] is not None:
images = [instance['image'] for instance in instances]
# if all(x is not None and x.shape == images[0].shape for x in images):
# batch['images'] = torch.stack(images)
# else:
batch['images'] = images
if 'add_im_loss' in instances[0]:
batch['add_im_loss'] = True
if 'max_num_crops' in instances[0]:
batch['max_num_crops'] = instances[0]['max_num_crops']
return batch
import json
import yaml
import torch
import random
import os
import glob
import pickle
from datasets import load_dataset
from .openx import OpenXDataItem
from tqdm import tqdm
class DataItem:
"""
Curate data items from all data sources
"""
def __init__(self, training_size=-1, local_run=False):
self.training_size = training_size
self.local_run = local_run
def _get_dataset_tag(self, data_path):
if "epic" in data_path.lower():
return "epic"
elif "open-x" in data_path or "openx" in data_path:
if 'traces' in data_path:
return "openx_magma"
else:
return "openx"
elif "sthv2" in data_path.lower():
return "sthv2"
elif "exoego4d" in data_path.lower():
return "exoego4d"
elif 'ego4d' in data_path.lower():
return "ego4d"
elif 'aitw' in data_path.lower():
return "aitw"
elif 'seeclick' in data_path.lower() and 'ocr' in data_path.lower():
return "seeclick_ocr"
elif 'seeclick' in data_path.lower():
return "seeclick"
elif 'mind2web' in data_path.lower():
return "mind2web"
elif 'vision2ui' in data_path.lower():
return "vision2ui"
elif 'llava' in data_path.lower():
return "llava"
elif 'magma' in data_path.lower():
return "magma"
elif 'sharegpt4v' in data_path.lower():
return "sharegpt4v"
else:
raise ValueError(f"Dataset tag not found for {data_path}")
def _get_items(self, data_path, image_folder=None, processor=None, conversation_lib=None):
if data_path.endswith(".json"):
list_data_dict = json.load(open(data_path, "r"))
elif data_path.endswith(".jsonl"):
list_data_dict = [json.loads(line) for line in open(data_path, "r")]
elif data_path.endswith(".pth"):
list_data_dict = torch.load(data_path, map_location="cpu")
# random.shuffle(list_data_dict)
else:
if self._get_dataset_tag(data_path) == "openx":
list_data_dict = OpenXDataItem()(data_path, image_folder, processor=processor, conversation_lib=conversation_lib, local_run=self.local_run)
elif self._get_dataset_tag(data_path) == "pixelprose":
# Load the dataset
list_data_dict = load_dataset(
data_path,
cache_dir=image_folder
)
else:
data_folder = os.path.dirname(data_path)
# get file name from data_path
data_files = data_path.split('/')[-1].split('+')
list_data_dict = []
for file in data_files:
json_path = os.path.join(data_folder, file + '.json')
list_data_dict.extend(json.load(open(json_path, "r")))
return list_data_dict
def __call__(self, data_path, processor=None, conversation_lib=None, is_eval=False):
assert data_path is not None, "Data path is not provided"
if data_path.endswith(".yaml"):
data_dict = yaml.load(open(data_path, "r"), Loader=yaml.FullLoader)
data_path_key = 'DATA_PATH' if not is_eval else 'DATA_PATH_VAL'
image_folder_key = 'IMAGE_FOLDER' if not is_eval else 'IMAGE_FOLDER_VAL'
assert len(data_dict[data_path_key]) == len(data_dict[image_folder_key]), "Data path and image folder mismatch"
items = {}
dataset_names = []
dataset_folders = []
for i, (data_path, image_folder) in enumerate(zip(data_dict[data_path_key], data_dict[image_folder_key])):
items_temp = self._get_items(data_path, image_folder, processor, conversation_lib)
dataset_tag = self._get_dataset_tag(data_path)
if dataset_tag != "openx":
# if self.training_size > 0:
# items_temp = items_temp[:self.training_size]
if dataset_tag in ['sthv2', "ego4d", "exoego4d"]:
for item in items_temp:
item['image_folder'] = image_folder
item['dataset_tag'] = dataset_tag
item['gpt_response'] = ''
item['global_instructions'] = item['annotations']
elif dataset_tag in ["openx_magma"]:
items_dict_temp = []
for item in items_temp:
items_dict_temp.append(
{
'image': item.replace('traces', 'images').replace('.pth', '.jpg'),
'trace': item,
'image_folder': image_folder,
'dataset_tag': dataset_tag
}
)
items_temp = items_dict_temp
else:
# add image_folder to each item
for item in items_temp:
item['image_folder'] = image_folder
# add dataset tag to each item
for item in items_temp:
item['dataset_tag'] = dataset_tag
if dataset_tag in items:
items[dataset_tag].extend(items_temp)
else:
items[dataset_tag] = items_temp
dataset_names.append(dataset_tag)
dataset_folders.append(image_folder)
else:
items = self._get_items(data_path)
dataset_names = None
dataset_folders = None
return items, dataset_names, dataset_folders
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import pandas as pd
import torch
import deepspeed
import glob
import pandas as pd
import transformers
import tokenizers
import random
import re
import cv2
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
import torch.distributed as dist
import collections
from PIL import Image
from io import BytesIO
from data.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from magma.image_processing_magma import MagmaImageProcessor
from magma.processing_magma import MagmaProcessor
from .data_item import DataItem
from . import *
from PIL import Image, ImageFile
from PIL import ImageDraw, ImageFont
from typing import List, Optional, Union
def preprocess_multimodal(
sources: Sequence[str],
data_args: None
) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
# move all DEFAULT_IMAGE_TOKEN to the beginning of the sentence
if DEFAULT_IMAGE_TOKEN in sentence['value']:
# count the number of DEFAULT_IMAGE_TOKEN in the sentence
num_image_tokens = sentence['value'].count(DEFAULT_IMAGE_TOKEN)
# remove all DEFAULT_IMAGE_TOKEN from the sentence
if data_args.mm_use_image_start_end and (DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN) in sentence['value']:
sentence['value'] = sentence['value'].replace(DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN +'\n', '').replace(DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN, '')
else:
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN + '\n', '').replace(DEFAULT_IMAGE_TOKEN, '')
# add num_image_tokens DEFAULT_IMAGE_TOKEN to the beginning of the sentence
sentence['value'] = (DEFAULT_IMAGE_TOKEN + '\n') * num_image_tokens + sentence['value']
if data_args.mm_use_image_start_end:
sentence['value'] = sentence['value'].replace('<image>', DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN)
return sources
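# Hedged illustration (made-up conversation value) of the reordering above, with
# mm_use_image_start_end disabled:
#   before: "Describe the scene.\n<image>\n"
#   after:  "<image>\nDescribe the scene.\n"
# With mm_use_image_start_end enabled, each '<image>' is additionally wrapped with
# DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN.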
def preprocess(
sources: Sequence[str],
processor: MagmaProcessor,
has_image: bool = False):
conversations = []
for i, source in enumerate(sources):
convs = copy.deepcopy(source)
for elem in convs:
elem['role'] = 'user' if elem['from'] in ['human', 'user'] else 'assistant'
elem['content'] = elem['value']
convs = [
{
"role": "system",
"content": "You are agent that can see, talk and act.",
},
] + convs
text = processor.tokenizer.apply_chat_template(
convs,
tokenize=False,
add_generation_prompt=False
)
conversations.append(text)
# NOTE: this is only for QWen
# get the sep1 and sep2
dummy_convs = [
{
"role": "system",
"content": "You are agent that can see, talk and act.",
},
{
"role": "user",
"content": ""
},
{
"role": "assistant",
"content": ""
}
]
dummy_text = processor.tokenizer.apply_chat_template(
dummy_convs,
tokenize=False,
add_generation_prompt=False,
)
empty_token_length = len(processor.tokenizer("").input_ids)
bos_token = processor.tokenizer.bos_token
eos_token = processor.tokenizer.eos_token
if 'phi' in processor.tokenizer.name_or_path.lower():
eos_token = '<|end|>\n'
elif 'qwen2-' in processor.tokenizer.name_or_path.lower():
bos_token = '<|im_start|>'
eos_token = '<|im_end|>\n'
segments = dummy_text.split(eos_token)[:-1]
sep1, sep2 = segments[-2], segments[-1]
if bos_token:
sep1 = sep1.replace(bos_token, '')
tokenizer = processor.tokenizer
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
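# The loop below masks everything except the assistant turns: the system prompt and each user
# instruction are set to IGNORE_INDEX in `targets`, so the loss is computed only on the model's
# responses.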
for k, (conversation, target) in enumerate(zip(conversations, targets)):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
if 'phi' in processor.tokenizer.name_or_path.lower():
# Phi-3 has a pad_token at the end
total_len = total_len + 1
conversation_sys = conversation.split(sep1)[0]
conversation = conversation[len(conversation_sys):]
rounds = conversation.split(sep1)[1:]
cur_len = len(tokenizer(conversation_sys).input_ids)
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep2)
if len(parts) != 2:
break
parts[0] = sep1 + parts[0] + sep2
rou = sep1 + rou
# NOTE: subtract empty_token_length (the number of special tokens the tokenizer adds to an empty string, e.g., the 128000 start token for llama3)
round_len = len(tokenizer(rou).input_ids) - empty_token_length
instruction_len = len(tokenizer(parts[0]).input_ids) - empty_token_length
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
print(
conversations[k]
)
return dict(
input_ids=input_ids,
labels=targets,
)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self,
processor: MagmaProcessor,
data_items: Dict,
dataset_names: List[str],
dataset_folders: List[str],
data_args: None):
super(LazySupervisedDataset, self).__init__()
self.processor = processor
self.data_args = data_args
self.data_items = []
self.conv_constructor = {}
if dataset_names is not None:
for dataset_name, dataset_folder in zip(dataset_names, dataset_folders):
if dataset_name in ['sharegpt4v', 'aitw', 'mind2web']:
self.data_items.extend(data_items[dataset_name])
elif dataset_name in ['seeclick', 'vision2ui', 'seeclick_ocr']:
self.conv_constructor[dataset_name] = eval(dataset_name)(
mm_use_trace_start_end=data_args.mm_use_trace_start_end,
mm_use_trace_speed=data_args.mm_use_trace_speed,
mm_use_image_start_end=data_args.mm_use_image_start_end,
mm_use_image_history=data_args.mm_use_image_history,
mm_use_som_tom=data_args.mm_use_som_tom,
mm_use_som_tom_orig_img=data_args.mm_use_som_tom_orig_img,
remove_static_trace_pts=data_args.remove_static_trace_pts,
spatial_quant_size=data_args.spatial_quant_size,
dataset_folder=dataset_folder,
show_trace=data_args.show_trace,
task=data_args.task,
training_size=data_args.training_size,
tokenizer=processor.tokenizer,
)
final_items = self.conv_constructor[dataset_name].filter_items(data_items[dataset_name])
self.data_items.extend(final_items)
else:
self.conv_constructor[dataset_name] = eval(dataset_name)(
mm_use_trace_start_end=data_args.mm_use_trace_start_end,
mm_use_trace_speed=data_args.mm_use_trace_speed,
mm_use_image_start_end=data_args.mm_use_image_start_end,
mm_use_image_history=data_args.mm_use_image_history,
mm_use_som_tom=data_args.mm_use_som_tom,
mm_use_som_tom_orig_img=data_args.mm_use_som_tom_orig_img,
remove_static_trace_pts=data_args.remove_static_trace_pts,
spatial_quant_size=data_args.spatial_quant_size,
dataset_folder=dataset_folder,
show_trace=data_args.show_trace,
task=data_args.task,
training_size=data_args.training_size,
tokenizer=processor.tokenizer,
)
final_items = self.conv_constructor[dataset_name].filter_items(data_items[dataset_name])
self.data_items.extend(final_items)
self.action_placeholder_token_id = self.processor.tokenizer.convert_tokens_to_ids('<action>')
def __len__(self):
return len(self.data_items)
@property
def lengths(self):
length_list = []
for sample in self.data_items:
img_tokens = 128 if ('image' in sample and sample['image'] is not None) else 0
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.data_items:
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
cur_len = cur_len if ('image' in sample and sample['image'] is not None) else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
item = copy.deepcopy(self.data_items[i])
if 'video' in item and item['video'][0] is not None:
assert item['image_folder'] is not None or self.data_args.image_folder is not None, "image_folder is not provided"
image_folder = self.data_args.image_folder if self.data_args.image_folder is not None else item['image_folder']
if item['dataset_tag'] in ['sthv2', 'ego4d', 'exoego4d']:
visual_trace_path = os.path.join(image_folder, item['trace'])
if os.path.exists(visual_trace_path):
try:
visual_traces = torch.load(visual_trace_path, map_location='cpu')
video_path = os.path.join(image_folder, item['video'].replace('/home/tanreuben/vlp_datasets/', ''))
item.update(visual_traces)
except Exception as e:
print(f"Error loading: {visual_trace_path}")
visual_traces = None
video_path = None
else:
print(f"Error: {visual_trace_path} not found")
visual_traces = None
video_path = None
item = self.conv_constructor[item['dataset_tag']](item=item, video_path=video_path, visual_traces=visual_traces)
else:
item['video'][0] = item['video'][0].replace('/mnt/data/video_datasets_visual_traces/YouCook2/', '')
video_path = os.path.join(image_folder, item['video'][0])
frame_start, frame_end = item['frame_interval'][0].item(), item['frame_interval'][1].item()
video_name = os.path.basename(video_path).split('.')[0]
if 'youcook2' in video_path.lower():
visual_trace_path = os.path.join(image_folder, 'all_detected_visual_traces_30fps', f'{video_name}_trace_{frame_start:09d}_{frame_end:09d}.pth')
else:
visual_trace_path = os.path.join(image_folder, 'visual_trace' if 'epic' in image_folder else 'visual_traces', video_name, f'trace_{frame_start:09d}_{frame_end:09d}.pth')
if os.path.exists(visual_trace_path):
visual_traces = torch.load(visual_trace_path, map_location='cpu')
else:
visual_traces = None
item = self.conv_constructor[item['dataset_tag']](item=item, video_path=video_path, visual_traces=visual_traces)
image = item['image']
num_crops = item['num_crops']
# if image is not a PIL image
if image is None:
base_img_size = self.processor.image_processor.base_img_size
image = Image.new('RGB', (base_img_size, base_img_size), (0, 0, 0))
item['image'] = image
num_crops = 1
image_pt = self.processor.image_processor(image, num_crops=num_crops, return_tensors='pt')
images = collections.defaultdict(list)
for key, val in image_pt.items():
images[key].append(val)
texts = [item["conversations"]]
elif 'image' in item and item['image'] is not None:
# cope with multiple images
image_folder = item['image_folder']
image_files = item['image']
if isinstance(image_files, str):
image_files = [image_files]
image_files = [image_path.replace("ming2web_images", "mind2web_images") for image_path in image_files]
# image_files = image_files*2
# item["conversations"][0]['value'] = item["conversations"][0]['value'] + '\n' + DEFAULT_IMAGE_TOKEN
images = collections.defaultdict(list)
for image_file in image_files:
image_file = image_file[1:] if image_file.startswith('/') else image_file
image_path = os.path.join(image_folder, image_file)
try:
if "trace" in self.data_items[i]:
trace_file = self.data_items[i]["trace"]
trace_path = os.path.join(image_folder, trace_file)
if os.path.exists(trace_path):
visual_traces = torch.load(trace_path, map_location='cpu')
item.update(visual_traces)
else:
visual_traces = None
video_path = image_path
item = self.conv_constructor[item['dataset_tag']](item=item, video_path=image_path, visual_traces=visual_traces)
image = item['image_data']
num_crops = item['num_crops']
if image is None:
base_img_size = self.processor.image_processor.base_img_size
image = Image.new('RGB', (base_img_size, base_img_size), (0, 0, 0))
num_crops = 1
# NOTE: override num_crops for robotics dataset
image = self.processor.image_processor(image, num_crops=num_crops, return_tensors='pt')
elif 'ocrs' in self.data_items[i]:
item = self.conv_constructor[item['dataset_tag']](item=item)
image = self.processor.image_processor(image, return_tensors='pt')
else:
# regular image sft dataset
image = Image.open(image_path).convert('RGB')
# if item['dataset_tag'] in ['seeclick', 'vision2ui']:
# image = self.processor.image_processor(image, num_crops=9, return_tensors='pt')
# else:
image = self.processor.image_processor(image, return_tensors='pt')
for key, val in image.items():
images[key].append(val)
except Exception as e:
print(f"Error: {e}")
base_img_size = self.processor.image_processor.base_img_size
image = Image.new('RGB', (base_img_size, base_img_size), (0, 0, 0))
image = self.processor.image_processor(image, num_crops=1, return_tensors='pt')
for key, val in image.items():
images[key].append(val)
texts = preprocess_multimodal(
copy.deepcopy([item["conversations"]]),
self.data_args)
else:
images = collections.defaultdict(list)
# image does not exist in the data, but the model is multimodal
base_img_size = self.processor.image_processor.base_img_size
image = Image.new('RGB', (base_img_size, base_img_size), (0, 0, 0))
image = self.processor.image_processor(image, num_crops=1, return_tensors='pt')
for key, val in image.items():
images[key].append(val)
item["conversations"][0]['value'] = DEFAULT_IMAGE_TOKEN + '\n' + item["conversations"][0]['value']
if self.data_args.mm_use_image_start_end:
item["conversations"][0]['value'] = item["conversations"][0]['value'].replace('<image>', DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN)
texts = [item["conversations"]]
data_dict = preprocess(
texts,
self.processor,
has_image=('image' in item and item['image'] is not None)
)
if self.action_placeholder_token_id in data_dict['input_ids']:
assert (data_dict['input_ids'] == self.action_placeholder_token_id).sum() == 7, "action token length should be 7 in input_ids"
assert self.action_placeholder_token_id in data_dict['labels'], "action token should be also in labels"
assert (data_dict['labels'] == self.action_placeholder_token_id).sum() == 7, "action token length should be 7 in labels"
# replace the action token with the actual action token item['action_token_ids']
action_token_ids = torch.tensor(item['action_token_ids'], dtype=torch.long)[None,:]
data_dict['input_ids'][data_dict['input_ids'] == self.action_placeholder_token_id] = action_token_ids
data_dict['labels'][data_dict['labels'] == self.action_placeholder_token_id] = action_token_ids
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
data_dict.update(images)
data_dict.update(
{
"dataset_name": item["dataset_tag"],
"item_id": i
}
)
del item
return data_dict
# Custom wrapper to combine Dataset and IterableDataset without loading IterableDataset in memory
class CombinedDataset(Dataset):
def __init__(self, dataset, iterable_dataset, local_run=False, seed=7):
self.dataset_len = []
if dataset is not None:
self.dataset_len.append(len(dataset)) # Length of the Dataset
if dist.is_initialized():
sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True,
seed=seed,
drop_last=False,
)
else:
sampler = None
self.iterable_dataset_a = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0 if local_run else 8, pin_memory=False) # DataLoader for the Dataset
self.iterable_iter_a = iter(self.iterable_dataset_a)
else:
self.iterable_dataset_a = None
self.iterable_iter_a = None
self.dataset_len.append(0)
if iterable_dataset is not None:
self.dataset_len.append(len(iterable_dataset)) # Length of the IterableDataset
self.iterable_dataset_b = iterable_dataset
self.iterable_iter_b = iter(self.iterable_dataset_b) # Iterator for the IterableDataset
else:
self.iterable_dataset_b = None
self.iterable_iter_b = None
self.dataset_len.append(0)
self.sampling_ratios = [float(item)/sum(self.dataset_len) for item in self.dataset_len]
print(f"total training data size: {sum(self.dataset_len)}")
print(f"sampling ratios: {self.sampling_ratios}")
def __len__(self):
# Length can be the maximum of both or some other logic
return sum(self.dataset_len)
def __getitem__(self, index):
# according to the sampling ratio, choose which dataset to sample
dataset_choice = random.choices([0, 1], self.sampling_ratios)[0]
if dataset_choice == 0:
# Fetch a sample from the map-style Dataset via its DataLoader iterator
try:
iterable_sample_a = next(self.iterable_iter_a)
except StopIteration:
# Reinitialize the iterator if it exhausts
self.iterable_iter_a = iter(self.iterable_dataset_a)
iterable_sample_a = next(self.iterable_iter_a)
iterable_sample_a['input_ids'] = iterable_sample_a['input_ids'][0]
iterable_sample_a['labels'] = iterable_sample_a['labels'][0]
iterable_sample_a['pixel_values'] = [item[0] for item in iterable_sample_a['pixel_values']]
iterable_sample_a['image_sizes'] = [item[0] for item in iterable_sample_a['image_sizes']]
return iterable_sample_a
else:
# Fetch a sample from the IterableDataset using its iterator
try:
iterable_sample_b = next(self.iterable_iter_b)
except StopIteration:
# Reinitialize the iterator if it exhausts
self.iterable_iter_b = iter(self.iterable_dataset_b)
iterable_sample_b = next(self.iterable_iter_b)
# print(f"oxe-rank-{rank}: {iterable_sample_b['dataset_name']}")
# Return a combined sample (modify based on your requirement)
return iterable_sample_b
def build_joint_dataset(
data_path: str,
processor: MagmaProcessor,
data_args: None,
is_eval: bool = False
) -> torch.utils.data.ConcatDataset:
data_items, dataset_names, dataset_folders = DataItem(training_size=data_args.training_size, local_run=data_args.local_run)(data_path, processor, None, is_eval=is_eval)
# pop out open-x dataset
openx_dataset = None
if 'openx' in data_items:
openx_dataset = data_items.pop('openx')
_ = dataset_folders.pop(dataset_names.index('openx'))
_ = dataset_names.pop(dataset_names.index('openx'))
lazy_dataset = None
if len(data_items) > 0:
lazy_dataset = LazySupervisedDataset(processor, data_items, dataset_names, dataset_folders, data_args)
# concatenate openx dataset and lazy_dataset
return CombinedDataset(lazy_dataset, openx_dataset, local_run=data_args.local_run)
else:
return LazySupervisedDataset(processor, data_items, dataset_names, dataset_folders, data_args)
from .data_utils import Ego4d as ego4d
import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from tqdm import tqdm
from PIL import Image
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
from data.conversations import Constructor
import logging
logger = logging.getLogger(__name__)
class Ego4d(Constructor):
def __init__(self, **kwargs):
super(Ego4d, self).__init__(**kwargs)
# load settings from settings.yaml file
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
self.settings = yaml.safe_load(file)
self.spatial_quant_size = kwargs.get('spatial_quant_size', 256) # this is also used for open-x
self.num_clusters = self.settings['trace_processor']['num_clusters']
self.root_dir = kwargs.get('dataset_folder', None)
self.task = kwargs.get('task', 'agent')
self.use_som_tom = kwargs.get('mm_use_som_tom', True)
if kwargs.get('training_size', 'default') == 'default':
self.training_size = self.settings['training'].get('size', -1)
else:
self.training_size = kwargs.get('training_size', -1)
# convert an 'M' suffix to 1,000,000, e.g., '10M' means 10,000,000
if isinstance(self.training_size, str) and 'M' in self.training_size:
self.training_size = int(float(self.training_size.replace('M', '')) * 1000000)
else:
self.training_size = int(self.training_size)
self.filtered_verb = [
'converse',
'walk',
'laugh',
'stand',
'move around',
'looks around',
]
def __call__(self, **kwargs):
return super()._construct_conv(**kwargs)
def filter_items(self, items):
"""
Filter invalid items
"""
filtered_items = []
print("Filtering items")
for item in tqdm(items):
global_instruction = item['global_instructions']
if len(global_instruction) == 0:
continue
# check if global_instruction contains any word in self.filtered_verb
# if so, skip this item
if any(verb in global_instruction for verb in self.filtered_verb):
continue
seg_name = item['video'].split('/')[-1]
start_str, end_str = seg_name.split('___')[0:2]
start_time = float(start_str.split('_')[-1])
end_time = float(end_str.split('_')[-1])
if (end_time-start_time) < 1:
continue
filtered_items.append(item)
if self.training_size > 0 and self.training_size < len(filtered_items):
# sample uniformly self.training_size samples from the filtered items
filtered_items = filtered_items[::(len(filtered_items)//self.training_size)]
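# e.g., with 100,000 filtered items and training_size=10,000 the stride is 10, keeping roughly
# every 10th item (the kept count can slightly exceed training_size due to integer division).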
print(f"Keep {len(filtered_items)} items from {len(items)} items")
return filtered_items
# tracker settings
tracker:
backward_tracking: true
ckpt_path: ./checkpoints/cotracker2.pth
grid_query_frame: 0
grid_size: 32
save_dir: ./
# sft settings
trace_processor:
num_clusters: 5
postive_factor_threshold: 0.5 # this is multiplied by the max value of the trace to get the threshold
postive_speed_threshold: 2 # this is the speed threshold for the positive trace
trace_planner:
quant_size: 200
skip_frames: 16
step_to_predict: 16 # use same setting as COIN since the videos have 30fps
step_rightmost_ratio: 0.5 # the ratio of the rightmost point to set as the start frame
training:
size: 1_000_000
from .data_utils import EpicKitchen as epic
import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from PIL import Image
from data.conversations import Constructor
class EpicKitchen(Constructor):
def __init__(self, **kwargs):
super(EpicKitchen, self).__init__(**kwargs)
# load settings from settings.yaml file
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
self.settings = yaml.safe_load(file)
self.spatial_quant_size = kwargs.get('spatial_quant_size', 256) # this is also used for open-x
self.num_clusters = self.settings['trace_processor']['num_clusters']
self.root_dir = kwargs.get('dataset_folder', None)
self.task = kwargs.get('task', 'agent')
self.use_som_tom = kwargs.get('mm_use_som_tom', True)
def __call__(self, **kwargs):
if self.task == "captioner":
return super()._construct_caption(**kwargs)
else:
return super()._construct_conv(**kwargs)
def filter_items(self, items):
"""
filter out items that are not suitable for conversation construction
"""
filtered_items = []
for item in items:
# remove closeup videos
if 'closeup' in item['gpt_response'][0] or \
'close-up' in item['gpt_response'][0] or \
'close up' in item['gpt_response'][0] or \
'What you should do next' not in item['gpt_response'][0]:
continue
# item['gpt_response'][0] = item['gpt_response'][0].replace('blue', 'yellow')
filtered_items.append(item)
print(f"Filtered {len(items) - len(filtered_items)} items from {len(items)} items")
return filtered_items
# tracker settings
tracker:
ckpt_path: "./checkpoints/cotracker2.pth"
grid_size: 32
grid_query_frame: 0
backward_tracking: True
save_dir: "./"
# sft settings
trace_processor:
num_clusters: 5
postive_factor_threshold: 0.5 # this is multiplied by the max value of the trace to get the threshold
postive_speed_threshold: 1 # this is the speed threshold for the positive trace
trace_planner:
step_rightmost_ratio: 0.5 # the ratio of the rightmost point to set as the start frame
from .data_utils import LlaVA as llava