Commit 0063a668 authored by chenzk

v1.0

import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from PIL import Image
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
from data.conversations import Constructor
from data.openx.action_tokenizer import ActionTokenizer
class OpenXMagma(Constructor):
def __init__(self, **kwargs):
super(OpenXMagma, self).__init__(**kwargs)
# load settings from settings.yaml file
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
self.settings = yaml.safe_load(file)
self.spatial_quant_size = kwargs.get('spatial_quant_size', 256) # this is also used for open-x
self.num_clusters = self.settings['trace_processor']['num_clusters']
self.root_dir = kwargs.get('dataset_folder', None)
self.task = kwargs.get('task', 'agent')
self.use_som_tom = kwargs.get('mm_use_som_tom', True)
tokenizer = kwargs.get('tokenizer', None)
assert tokenizer, "Tokenizer is not provided"
if self.mm_use_image_start_end:
self.image_placeholder = '<image_start><image><image_end>\n'
else:
self.image_placeholder = '<image>\n'
self.action_tokenizer = ActionTokenizer(tokenizer)
self.trace_width = 256
self.trace_height = 256
def __call__(self, item, video_path, visual_traces, width=512, height=512):
item['num_crops'] = 1
if video_path is None and visual_traces is None:
dummy_conversations = []
dummy_conversations.append({'from': 'human', 'value': f"{self.image_placeholder}\nWhat is in this image?"})
dummy_conversations.append({'from': 'gpt', 'value': "This is a blank image."})
item['conversations'] = dummy_conversations
item['image_data'] = None
return item
frame_start, frame_end = item['frame_index'], item['frame_index'] + 16
task_description = item['lang']
gpt_response = task_description
if self.mm_use_image_history:
# randomly sample at most 7 unique indices in range [0, frame_start) as extra history frames (currently disabled)
# if torch.rand(1).item() < 0.5:
# frame_idx = torch.randperm(frame_start)[:7].sort().values.tolist() + [frame_start]
# else:
frame_idx = [frame_start]
else:
frame_idx = [frame_start]
item['image_data'] = self._get_frames_with_idx(video_path, frame_idx, (width, height))
# conversation 1: Q: to do the task, what should be the next action? A: next action
image_placeholder = ''.join([self.image_placeholder]*len(item['image_data']))
item['conversations'] = [
{"from": "human", "value": f"{image_placeholder}\nWhat action should the robot take to {gpt_response}?"},
{"from": "gpt", "value": ''.join(['<action>']*7)}, # placeholder for action tokens
]
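# For example (illustrative task string), with task_description "pick up the red block"
# and a single input frame, the human turn renders as
#   "<image_start><image><image_end>\n\nWhat action should the robot take to pick up the red block?"
# and the gpt turn is seven "<action>" placeholders, paired with the discretized
# action token ids stored in item['action_token_ids'] below.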
action = visual_traces['action']
action_token_ids = self.action_tokenizer.encode_actions_to_token_ids(action)
item['action_token_ids'] = action_token_ids
# conversation 2: Q: what is the robot doing? A: task description
# conv_user, conv_gpt, gpt_response_todo = self._construct_conv_semantic(item, gpt_response)
# conversations.append({'from': 'human', 'value': conv_user})
# conversations.append({'from': 'gpt', 'value': conv_gpt})
# item['image_data'].append(self._get_frame(video_path, frame_start, 0, (width, height)))
if not self.use_som_tom:
return item
if visual_traces is None:
return item
if item['dataset_name'].decode('utf-8') in ["berkeley_cable_routing", "kuka"]:
return item
visual_traces['pred_tracks'], visual_traces['pred_visibility'] = visual_traces['trace_info']
if width != self.trace_width:
visual_traces['pred_tracks'][...,0] = visual_traces['pred_tracks'][...,0] * width // self.trace_width
if height != self.trace_height:
visual_traces['pred_tracks'][...,1] = visual_traces['pred_tracks'][...,1] * height // self.trace_height
if len(visual_traces['pred_tracks'].shape) == 3:
visual_traces['pred_tracks'] = visual_traces['pred_tracks'][None]
if len(visual_traces['pred_visibility'].shape) == 2:
visual_traces['pred_visibility'] = visual_traces['pred_visibility'][None]
frame_pos = 0
pred_tracks = visual_traces['pred_tracks'][:, frame_pos:]
pred_visibility = visual_traces['pred_visibility'][:, frame_pos:]
step_to_predict = pred_tracks.size(1)
if step_to_predict == 0:
return item
pred_tracks_history = visual_traces['pred_tracks'][:, :max(1, frame_pos+1)]
pred_visibility_history = visual_traces['pred_visibility'][:, :max(1, frame_pos+1)]
# only keep points that are visible for at least half of the steps
valid_idx = pred_visibility[0].sum(0) > 0.5*pred_tracks.shape[1]
if valid_idx.sum() <= 1:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
return item
pred_tracks = pred_tracks[:, :, valid_idx]
pred_visibility = pred_visibility[:, :, valid_idx]
pred_tracks_history = pred_tracks_history[:, :, valid_idx]
pred_visibility_history = pred_visibility_history[:, :, valid_idx]
# calculate the trajectory length for pred_tracks
pred_tracks_length = self.trace.visual_trace_length(pred_tracks, pred_visibility, (1, 1)).squeeze(0)
# if more than 80% of the trace lengths are larger than 1, assume there is camera motion
camera_motion = (pred_tracks_length > 1).sum() > 0.8*pred_tracks_length.size(0)
camera_motion = True if item['dataset_tag'] in ['ego4d', 'epic'] else camera_motion
start_pos = pred_tracks[:, 0][0]
reference_pts_np = start_pos.cpu().numpy().reshape(-1, 2)
if camera_motion:
# remove camera motion using homography transformation
try:
future_pts_transformed = []
for k in range(1, pred_tracks.shape[1]):
future_pts = pred_tracks[:, k][0]
future_pts_np = future_pts.cpu().numpy().reshape(-1, 2)
try:
(H, status) = cv2.findHomography(future_pts_np, reference_pts_np, cv2.RANSAC, 4.0)
except Exception as e:
continue
future_pts_np_transformed = cv2.perspectiveTransform(future_pts_np.reshape(1, -1, 2), H).reshape(-1, 2)
future_pts_transformed_k = torch.tensor(future_pts_np_transformed, dtype=torch.float32)
future_pts_transformed.append(future_pts_transformed_k)
pred_tracks = torch.stack([start_pos] + future_pts_transformed, dim=0).unsqueeze(0)
except Exception as e:
pass
if pred_tracks_history.size(1) > 0:
try:
history_pts_transformed = []
for k in range(0, pred_tracks_history.shape[1]):
history_pts = pred_tracks_history[:, k][0]
history_pts_np = history_pts.cpu().numpy().reshape(-1, 2)
try:
(H, status) = cv2.findHomography(history_pts_np, reference_pts_np, cv2.RANSAC, 4.0)
except Exception as e:
continue
history_pts_np_transformed = cv2.perspectiveTransform(history_pts_np.reshape(1, -1, 2), H).reshape(-1, 2)
history_pts_transformed_k = torch.tensor(history_pts_np_transformed, dtype=torch.float32)
history_pts_transformed.append(history_pts_transformed_k)
pred_tracks_history = torch.stack(history_pts_transformed, dim=0).unsqueeze(0)
except Exception as e:
pass
# step 2: find positive traces and negative traces
track_length = self.trace.visual_trace_length(pred_tracks, pred_visibility, (1, 1)).squeeze(0)
threshold = 1 # max(track_length.max(), 2) * self.settings['trace_processor']['postive_factor_threshold']
# video is almost static
if (track_length > threshold).sum() <= 1:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
return item
else:
# find the positive traces and negative traces
pos_tracks = pred_tracks[:, :, track_length > threshold]
pos_visibility = pred_visibility[:, :, track_length > threshold]
pos_tracks_history = pred_tracks_history[:, :, track_length > threshold]
pos_visibility_history = pred_visibility_history[:, :, track_length > threshold]
neg_tracks = pred_tracks[:, :, track_length <= threshold]
neg_tracks_history = pred_tracks_history[:, :, track_length <= threshold]
# clustering for positive traces
# randomly sample a number of clusters between 2 and 4
num_clusters_pos = torch.randint(2, 5, (1,)).item()
pos_sampled_ids = self.trace.cluster_traces_kmeans(pos_tracks, n_clusters=num_clusters_pos, positive=True)
if pos_sampled_ids is None:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
return item
pos_tracks = pos_tracks[:, :, pos_sampled_ids.bool()]
pos_visibility = pos_visibility[:, :, pos_sampled_ids.bool()]
pos_tracks_history = pos_tracks_history[:, :, pos_sampled_ids.bool()]
pos_visibility_history = pos_visibility_history[:, :, pos_sampled_ids.bool()]
# clustering for negative traces
num_clusters_neg = torch.randint(4, 10, (1,)).item()
neg_sampled_ids = self.trace.cluster_traces_kmeans(neg_tracks, n_clusters=num_clusters_neg)
if neg_sampled_ids is None:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
return item
neg_tracks = neg_tracks[:, :, neg_sampled_ids.bool()]
neg_tracks_history = neg_tracks_history[:, :, neg_sampled_ids.bool()]
image = self._get_frame(video_path, frame_start, frame_pos, (width, height))
if image is None:
image = self._get_frame(video_path, frame_start, 0, (width, height))
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, 0)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
return item
# we have two choices: a) use visual prompting and b) use textual prompting
if self.settings['som']['format'] == "visual":
# image = tom_prompting(self.trace, image, pos_tracks_history, neg_tracks_history, draw_som_positive=False, draw_som_negative=False)
image, pos_traces_to_mark, neg_traces_to_mark, pos_mark_ids, neg_mark_ids, all_idx = \
som_prompting(image, pos_tracks, neg_tracks, draw_som_positive=True, draw_som_negative=True)
mark_ids = sorted([key for key in pos_traces_to_mark.keys()] + [key for key in neg_traces_to_mark.keys()])
else:
# image = tom_prompting(self.trace, image, pos_tracks_history, neg_tracks_history, draw_som_positive=False, draw_som_negative=False)
image, pos_traces_to_mark, neg_traces_to_mark, pos_mark_ids, neg_mark_ids, all_idx = \
som_prompting(image, pos_tracks, neg_tracks, draw_som_positive=False, draw_som_negative=False)
# aggregate the starting points of the traces from pos_trace_to_mark and neg_trace_to_mark
traces_to_mark = {**pos_traces_to_mark, **neg_traces_to_mark}
traces_to_mark = dict(sorted(traces_to_mark.items()))
mark_positions = {key: (self.spatial_quant_size*val[0][0]/torch.tensor([width, height])).int().tolist() for key, val in traces_to_mark.items()}
# turn mark_positions to str
# mark_ids = ', '.join([f"Mark {key} at [{float(val[0])/self.spatial_quant_size:.2f},{float(val[1])/self.spatial_quant_size:.2f}]" for key, val in mark_positions.items()])
mark_ids = ', '.join([f"Mark {key} at {val}" for key, val in mark_positions.items()])
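# e.g., "Mark 2 at [63, 120], Mark 5 at [210, 95]" (illustrative values), with each
# coordinate quantized onto the spatial_quant_size x spatial_quant_size grid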
# visualize the traces
if self.show_trace:
import pdb; pdb.set_trace()
images = [image] * pos_tracks.shape[1]
video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float()*255
self.trace.visualizer.save_dir = "./release/robotics"
_ = self.trace.visualize(video, pos_tracks, pos_visibility, filename=f"{item['trace'].replace('/', '_').replace('.pth', '')}", mode="rainbow")
pos_traces_to_mark = dict(sorted(pos_traces_to_mark.items()))
mark_trace_history = ''
mark_trace_future = ''
valid_marks = {}
speeds = {}
for key, val in pos_traces_to_mark.items():
# randomly select a frame position (but not the last frame)
# frame_pos = torch.randint(0, trace.size(0)-1, (1,)).item()
trace = val[0]
trace[:, 0] = self.spatial_quant_size * trace[:, 0] / width
trace[:, 1] = self.spatial_quant_size * trace[:, 1] / height
trace_temp = trace.clone()
# remove (almost) static points
trace_temp = self.trace.remove_close_points_tensor(trace_temp, 2)
# remove invisible points
trace_temp = trace_temp[(trace_temp > 0).sum(1) == 2]
# if trace_temp.size(0) <= step_to_predict // 4:
# continue
# calculate motion speed
# if trace_temp.size(0) < step_to_predict:
# trace_temp = torch.cat([trace_temp, trace_temp[-1].repeat(step_to_predict - trace_temp.size(0), 1)], dim=0)
# elif trace_temp.size(0) > step_to_predict:
# trace_temp = trace_temp[:step_to_predict]
# calculate speed
speed = torch.norm(trace_temp[1:] - trace_temp[:-1], dim=1).mean()
if torch.isnan(speed):
continue
speeds[key] = speed.item()
if speed < self.settings['trace_processor']['postive_speed_threshold']:
continue
# trace_history = trace[0]
# val_str_history = '[' + ','.join([f'[{x[0]},{x[1]}]' for x in trace_history.tolist()]) + ']'
# mark_trace_history += f'\"Mark {key}\": \"{val_str_history}\"\n'
# round trace_temp
if self.remove_static_trace_pts:
valid_marks[key] = trace_temp.int()
else:
valid_marks[key] = trace.int()
# NOTE: there was a bug here
# val_str_future = '[' + ','.join([f'[{float(x[0])/self.spatial_quant_size:.2f},{float(x[1])/self.spatial_quant_size:.2f}]' for x in valid_marks[key][1:].tolist()]) + ']'
val_str_future = '[' + ','.join([f'[{x[0]},{x[1]}]' for x in valid_marks[key][1:].tolist()]) + ']'
mark_trace_future += f'\"Mark {key}\": \"{val_str_future}\"\n\n'
if len(mark_trace_future) > 0:
num_future_steps = [val.shape[0]-1 for val in valid_marks.values()]
step_to_predict = max(num_future_steps)
if self.mm_use_trace_speed:
# calculate the average speed of the marks
avg_speed = int(sum(speeds.values()) / len(speeds))
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks: {mark_ids}.\n"
f"The robot is doing: {gpt_response}. To finish the task, how to move the numerical marks in the image with speed {avg_speed} for the next {step_to_predict} steps?\n"
)
else:
conv_user = (
f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks: {mark_ids}.\n"
f"The robot is doing: {gpt_response}. To finish the task, how to move the numerical marks in the image for the next {step_to_predict} steps?\n"
)
# formmated_val = ', '.join([f"Mark {key} at [{float(val[0][0].item())/self.spatial_quant_size:.2f},{float(val[0][1].item())/self.spatial_quant_size:.2f}]" for key, val in valid_marks.items()])
formmated_val = ', '.join([f"Mark {key} at [{val[0][0].item()},{val[0][1].item()}]" for key, val in valid_marks.items()])
if self.mm_use_trace_start_end:
mark_trace_future = f'<trace_start>{mark_trace_future}<trace_end>'
conv_gpt = f"{formmated_val} should be moved, and their future positions are:\n\n{mark_trace_future}"
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
else:
for key, val in neg_traces_to_mark.items():
trace = val[0]
trace[:, 0] = self.spatial_quant_size * trace[:, 0] / width
trace[:, 1] = self.spatial_quant_size * trace[:, 1] / height
conv_user, conv_gpt, image = self._construct_conv_som(item, image, visual_traces, frame_pos, pos_traces_to_mark, neg_traces_to_mark, normalize=False)
item['conversations'].append({'from': 'human', 'value': conv_user})
item['conversations'].append({'from': 'gpt', 'value': conv_gpt})
item['image_data'].append(image)
return item
def filter_items(self, items):
"""
Filter invalid items
"""
return items
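
# Expected inputs for OpenXMagma.__call__, summarized from the code above
# (field names come from this file, not from external documentation):
#   item: dict with 'frame_index' (int), 'lang' (task description), 'dataset_name'
#         (bytes), 'dataset_tag' (str), and 'trace' (path string, used only when
#         visualizing traces).
#   visual_traces: dict with 'action' (raw robot action fed to
#         ActionTokenizer.encode_actions_to_token_ids) and 'trace_info', a
#         (pred_tracks, pred_visibility) pair shaped [T, N, 2] / [T, N] or
#         [1, T, N, 2] / [1, T, N].
# The call returns `item` augmented with 'num_crops', 'conversations',
# 'image_data' and (except for the dummy path) 'action_token_ids'.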
# tracker settings
tracker:
ckpt_path: "./checkpoints/cotracker2.pth"
grid_size: 32
grid_query_frame: 0
backward_tracking: True
save_dir: "./"
som:
format: 'visual'
# sft settings
trace_processor:
num_clusters: 5
postive_factor_threshold: 0.5 # multiplied by the max value of the trace to get the positive threshold
postive_speed_threshold: 1 # this is the speed threshold for the positive trace
trace_planner:
step_rightmost_ratio: 0.5 # the ratio of the rightmost point to set as the start frame
from .data_utils import SeeClick as seeclick
import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from tqdm import tqdm
from PIL import Image
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
from data.conversations import Constructor
class SeeClick(Constructor):
def __init__(self, **kwargs):
super(SeeClick, self).__init__(**kwargs)
# load settings from settings.yaml file
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
self.settings = yaml.safe_load(file)
self.spatial_quant_size = kwargs.get('spatial_quant_size', 256) # this is also used for open-x
self.num_clusters = self.settings['trace_processor']['num_clusters']
self.root_dir = kwargs.get('dataset_folder', None)
self.task = kwargs.get('task', 'agent')
self.use_som_tom = kwargs.get('mm_use_som_tom', True)
self.use_som_tom_orig_img = kwargs.get('mm_use_som_tom_orig_img', False)
def __call__(self, **kwargs):
return super()._construct_conv(**kwargs)
def filter_items(self, items):
"""
Filter invalid items
"""
if self.use_som_tom and not self.use_som_tom_orig_img:
return items
elif self.use_som_tom and self.use_som_tom_orig_img:
print("Adding original image to SoM")
for item in tqdm(items):
image_path = item['image']
if "mobile" in image_path:
item['image'] = [image_path.replace("combined_image_processed", "combined")] + [item['image']]
for conv in item['conversations']:
# add a second image placeholder so the conversation consumes both the original and the SoM-processed image
conv['value'] = conv['value'].replace("<image>", "<image>\n<image>")
elif "web" in image_path:
item['image'] = [image_path.replace("seeclick_web_imgs_processed", "seeclick_web_imgs")] + [item['image']]
for conv in item['conversations']:
# add a second image placeholder so the conversation consumes both the original and the SoM-processed image
conv['value'] = conv['value'].replace("<image>", "<image>\n<image>")
else:
continue
else:
print("Filtering SoM from seeclick training data")
for item in tqdm(items):
image_path = item['image']
if "mobile" in image_path:
item['image'] = image_path.replace("combined_image_processed", "combined")
for conv in item['conversations']:
# remove 'Mark: {id}' from the conversation conv['value'], e.g., Mark: 11
conv['value'] = re.sub(r' Mark: \d+', '', conv['value']).strip()
conv['value'] = re.sub(r' mark: \d+', '', conv['value']).strip()
elif "web" in image_path:
item['image'] = image_path.replace("seeclick_web_imgs_processed", "seeclick_web_imgs")
for conv in item['conversations']:
# remove 'Mark: {id}' from the conversation conv['value']
conv['value'] = re.sub(r' Mark: \d+', '', conv['value']).strip()
conv['value'] = re.sub(r' mark: \d+', '', conv['value']).strip()
else:
continue
return items
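
# Example of the SoM cleanup above on an illustrative value:
#   re.sub(r' Mark: \d+', '', "Click the search icon Mark: 11").strip()
#   -> "Click the search icon"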
# tracker settings
tracker:
backward_tracking: true
ckpt_path: ./checkpoints/cotracker2.pth
grid_query_frame: 0
grid_size: 32
save_dir: ./
# sft settings
trace_processor:
num_clusters: 3
trace_planner:
quant_size: 200
skip_frames: 16
step_to_predict: 16 # use same setting as COIN since the videos have 30fps
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<image_patch>"
DEFAULT_IM_START_TOKEN = "<image_start>"
DEFAULT_IM_END_TOKEN = "<image_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
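
# Illustrative sketch (not referenced elsewhere in this commit) of a common
# LLaVA-style way these constants are combined, assuming `encode_fn` is any
# callable that maps a text chunk to a list of token ids. The prompt is split on
# DEFAULT_IMAGE_TOKEN and IMAGE_TOKEN_INDEX is spliced in at each image slot so a
# model can later swap that placeholder id for visual features; IGNORE_INDEX is
# what such pipelines typically assign to prompt positions when building labels.
def tokenize_with_image_placeholders(prompt, encode_fn):
    input_ids = []
    chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
    for i, chunk in enumerate(chunks):
        if chunk:
            input_ids.extend(encode_fn(chunk))
        if i < len(chunks) - 1:
            # placeholder id that stands in for one image's visual tokens
            input_ids.append(IMAGE_TOKEN_INDEX)
    return input_ids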
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
def som_prompting(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False):
"""
draw marks on the image
"""
image_size = image.size
draw = ImageDraw.Draw(image)
def get_text_size(text, image, font):
im = Image.new('RGB', (image.width, image.height))
draw = ImageDraw.Draw(im)
_, _, width, height = draw.textbbox((0, 0), text=text, font=font)
return width, height
def expand_bbox(bbox):
x1, y1, x2, y2 = bbox
return [x1-4, y1-4, x2+4, y2+4]
def draw_marks(draw, points, text_size, id, font_size):
txt = str(id)
draw.ellipse(((points[0]-max(text_size)//2-1, points[1]-max(text_size)//2-1, points[0]+max(text_size)//2+1, points[1]+max(text_size)//2+1)), fill='red')
draw.text((points[0]-text_size[0] // 2, points[1]-text_size[1] // 2-3), txt, fill='white', font=font_size)
fontsize = 1
font = ImageFont.truetype("data/utils/arial.ttf", fontsize)
txt = "55"
while min(get_text_size(txt, image, font)) < 0.03*image_size[0]:
# iterate until the text size is just larger than the criteria
fontsize += 1
font = ImageFont.truetype("data/utils/arial.ttf", fontsize)
text_size_2digits = get_text_size('55', image, font)
text_size_1digit = get_text_size('5', image, font)
text_size = {
1: text_size_1digit,
2: text_size_2digits
}
# draw the starting point of positive traces on image
num_pos = pos_traces.shape[2]
pos_idx = torch.arange(num_pos)
pos_traces_to_mark = pos_traces
# negative traces (no additional subsampling is done here)
num_neg = neg_traces.shape[2]
neg_idx = torch.arange(num_neg)
neg_traces_to_mark = neg_traces
num_traces_total = pos_traces_to_mark.shape[2] + neg_traces_to_mark.shape[2]
# shuffle the indices
all_idx = torch.randperm(num_traces_total)
pos_mark_ids = []; neg_mark_ids = []
pos_traces_som = {}
for i in range(pos_traces_to_mark.shape[2]):
pos = pos_traces_to_mark[:,:,i]
mark_id = all_idx[i].item()
text_size = get_text_size(str(mark_id+1), image, font)
if draw_som_positive:
draw_marks(draw, pos[0][0], text_size, mark_id+1, font)
pos_traces_som[mark_id+1] = pos
pos_mark_ids.append(mark_id+1)
neg_traces_som = {}
for i in range(neg_traces_to_mark.shape[2]):
neg = neg_traces_to_mark[:,:,i]
mark_id = all_idx[pos_traces_to_mark.shape[2]+i].item()
text_size = get_text_size(str(mark_id+1), image, font)
if draw_som_negative:
draw_marks(draw, neg[0][0], text_size, mark_id+1, font)
neg_traces_som[mark_id+1] = neg
neg_mark_ids.append(mark_id+1)
return image, pos_traces_som, neg_traces_som, pos_mark_ids, neg_mark_ids, all_idx
def som_prompting_with_priors(image, pos_traces_som, neg_traces_som, pos_mark_ids, neg_mark_ids, all_idx, step_offset=1, draw_som_positive=False, draw_som_negative=False):
"""
draw marks on the image
"""
image_size = image.size
draw = ImageDraw.Draw(image)
def get_text_size(text, image, font):
im = Image.new('RGB', (image.width, image.height))
draw = ImageDraw.Draw(im)
_, _, width, height = draw.textbbox((0, 0), text=text, font=font)
return width, height
def expand_bbox(bbox):
x1, y1, x2, y2 = bbox
return [x1-4, y1-4, x2+4, y2+4]
def draw_marks(draw, points, text_size, id, font_size):
txt = str(id)
draw.ellipse(((points[0]-max(text_size)//2-1, points[1]-max(text_size)//2-1, points[0]+max(text_size)//2+1, points[1]+max(text_size)//2+1)), fill='red')
draw.text((points[0]-text_size[0] // 2, points[1]-text_size[1] // 2-3), txt, fill='white', font=font_size)
fontsize = 1
font = ImageFont.truetype("data/utils/arial.ttf", fontsize)
txt = "55"
while min(get_text_size(txt, image, font)) < 0.02*image_size[0]:
# iterate until the text size is just larger than the criteria
fontsize += 1
font = ImageFont.truetype("data/utils/arial.ttf", fontsize)
text_size_2digits = get_text_size('55', image, font)
text_size_1digit = get_text_size('5', image, font)
text_size = {
1: text_size_1digit,
2: text_size_2digits
}
for key, val in pos_traces_som.items():
mark_id = key
pos = val[:,step_offset if step_offset < val.shape[1] else -1]
text_size = get_text_size(str(mark_id), image, font)
if draw_som_positive:
draw_marks(draw, pos[0], text_size, mark_id, font)
for key, val in neg_traces_som.items():
mark_id = key
neg = val[:,step_offset if step_offset < val.shape[1] else -1]
text_size = get_text_size(str(mark_id), image, font)
if draw_som_negative:
draw_marks(draw, neg[0], text_size, mark_id, font)
return image
def tom_prompting(trace, image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False):
"""
draw trace-of-marks on the image
"""
image_size = image.size
# draw traces for all points
# get all traces
tracks = torch.cat([pos_traces, neg_traces], dim=2).cpu().numpy()
_, T, N, _ = tracks.shape
vector_colors = np.zeros((T, N, 3))
if trace.visualizer.mode == "rainbow":
y_min, y_max = (
tracks[0, 0, :, 1].min(),
tracks[0, 0, :, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
color = trace.visualizer.color_map(norm(tracks[0, 0, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(trace.visualizer.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
# PIL to numpy
image = np.array(image).astype(np.uint8)
# unsqueeze image to 4D
curr_tracks = tracks[0]
curr_colors = vector_colors
image = trace.visualizer._draw_pred_tracks(image, curr_tracks, curr_colors)
image = Image.fromarray(image)
return image
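
if __name__ == "__main__":
    # Minimal smoke test for som_prompting, included as an illustrative sketch
    # rather than part of the training pipeline. It assumes the font file
    # data/utils/arial.ttf referenced above is reachable from the working directory.
    image = Image.new('RGB', (512, 512), color='black')
    # traces are laid out as [batch, time, num_traces, xy]
    pos_tracks = torch.rand(1, 8, 3, 2) * 512
    neg_tracks = torch.rand(1, 8, 5, 2) * 512
    image, pos_som, neg_som, pos_ids, neg_ids, all_idx = som_prompting(
        image, pos_tracks, neg_tracks, draw_som_positive=True, draw_som_negative=True
    )
    print("positive mark ids:", sorted(pos_som.keys()))
    print("negative mark ids:", sorted(neg_som.keys()))
    image.save("som_demo.png")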
import io
import os
import cv2
import json
import torch
import numpy as np
from PIL import Image
from IPython import display
from tqdm import tqdm
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from matplotlib import cm
import faiss
from kmeans_pytorch import kmeans
class visual_trace():
def __init__(
self,
grid_size=10,
grid_query_frame=0,
linewidth=2,
backward_tracking=False,
save_dir="./videos",
):
self.grid_size = grid_size
self.grid_query_frame = grid_query_frame
self.backward_tracking = backward_tracking
self.visualizer = Visualizer(save_dir=save_dir, pad_value=0, linewidth=linewidth, tracks_leave_trace=-1)
def extract_visual_trace(self, video):
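# NOTE: self.model and self.device are assumed to be attached to the instance
# elsewhere; __init__ above does not load a tracker (visual_tracker further down
# shows the CoTrackerPredictor loading).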
video = video.to(self.device)
pred_tracks, pred_visibility = self.model(
video,
grid_size=self.grid_size,
grid_query_frame=self.grid_query_frame,
backward_tracking=self.backward_tracking,
# segm_mask=segm_mask
)
return video, pred_tracks, pred_visibility
def visual_trace_length(self, pred_tracks, pred_visibility, image_size):
"""
Compute the length of the visual trace
pred_tracks: e.g., [1, 77, 225, 2]
pred_visibility: e.g., [1, 77, 225]
image_size: e.g., [720, 1280]
"""
pred_tracks_normalized = pred_tracks / torch.tensor(image_size).float()[None, None, None, :].to(pred_tracks.device)
pred_visiblity_float = pred_visibility[:, 1:].float().to(pred_tracks.device)
consecutive_displacement = torch.norm(pred_tracks_normalized[:, 1:] - pred_tracks_normalized[:, :-1], dim=3)
# average_displacement = (consecutive_displacement * pred_visiblity_float).sum(1) / (1e-5 + pred_visiblity_float.sum(1))
average_displacement = consecutive_displacement.mean(1)
return average_displacement
def visualize(self, video, pred_tracks, pred_visibility, filename="visual_trace.mp4", mode="rainbow"):
if mode == "rainbow":
self.visualizer.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.visualizer.color_map = cm.get_cmap(mode)
return self.visualizer.visualize(
video,
pred_tracks,
pred_visibility,
query_frame=0 if self.backward_tracking else self.grid_query_frame,
filename=filename,
)
@classmethod
def cluster_traces(self, traces, n_clusters=3):
try:
traces_for_clustering = traces[0].transpose(0, 1)
# pred_tracks_4_clustering = pred_tracks_4_clustering - pred_tracks_4_clustering[:, :1]
traces_for_clustering = traces_for_clustering.flatten(1)
kmeans = faiss.Kmeans(
traces_for_clustering.shape[1],
min(n_clusters, traces_for_clustering.shape[0]),
niter=50,
verbose=False,
min_points_per_centroid=1,
max_points_per_centroid=10000000,
)
kmeans.train(traces_for_clustering.cpu().numpy())
distances, cluster_ids_x_np = kmeans.index.search(traces_for_clustering.cpu().numpy(), 1)
cluster_ids_x = torch.from_numpy(cluster_ids_x_np).to(traces_for_clustering.device)
except Exception:
print("kmeans failed")
return None
# sample at least one id per cluster (roughly 20%, clamped to one by the min(1, ...) cap below)
sampled_ids = cluster_ids_x.new_zeros(cluster_ids_x.size(0)).to(traces.device)
for cluster_id in range(min(n_clusters, traces_for_clustering.shape[0])):
cluster_idx = (cluster_ids_x == cluster_id).nonzero().squeeze(1)
num_pts_to_sample = max(1, min(1, int(0.2*cluster_idx.size(0))))
if num_pts_to_sample > 0:
# TODO: random sample is a bit dummy, need a better sampling algo here
sampled_idx = torch.randperm(cluster_idx.size(0))[:num_pts_to_sample]
sampled_ids[cluster_idx[sampled_idx]] = 1
return sampled_ids
@classmethod
def cluster_traces_kmeans(self, traces, n_clusters=3, positive=False):
x = traces[0].transpose(0, 1).flatten(1)
if x.shape[0] == 0:
return None
elif x.shape[0] == 1:
return torch.ones(1).to(traces.device)
cluster_ids_x, cluster_centers = kmeans(
X=x, num_clusters=min(n_clusters, x.shape[0]), distance='euclidean', device=x.device, tqdm_flag=False
)
# sample at least one id per cluster (roughly 20%, clamped to one by the min(1, ...) cap below)
sampled_ids = cluster_ids_x.new_zeros(cluster_ids_x.size(0)).to(traces.device)
for cluster_id in range(min(n_clusters, cluster_ids_x.shape[0])):
cluster_idx = (cluster_ids_x == cluster_id).nonzero().squeeze(1)
num_pts_to_sample = max(1, min(1, int(0.2*cluster_idx.size(0))))
if num_pts_to_sample > 0:
# TODO: random sample is a bit dummy, need a better sampling algo here
sampled_idx = torch.randperm(cluster_idx.size(0))[:num_pts_to_sample]
sampled_ids[cluster_idx[sampled_idx]] = 1
return sampled_ids
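# Both clustering helpers return a 0/1 tensor with one entry per trace; callers
# index the trace tensors with `sampled_ids.bool()` to keep roughly one
# representative trace per cluster.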
def remove_close_points_tensor(self, trajectory, min_distance=2):
"""
Removes points from the 2D trajectory that are closer than min_distance apart.
Parameters:
trajectory (torch.Tensor): A tensor of shape (N, 2) representing N points in 2D space.
min_distance (float): The minimum distance threshold for points to be retained.
Returns:
torch.Tensor: A filtered tensor of points where consecutive points are at least min_distance apart.
"""
# Start with the first point
filtered_trajectory = [trajectory[0]]
# Iterate through the points
for i in range(1, trajectory.size(0)):
prev_point = filtered_trajectory[-1]
curr_point = trajectory[i]
# Calculate the Euclidean distance between the previous point and the current point
distance = torch.norm(curr_point - prev_point)
# Keep the point if it's at least min_distance apart from the previous one
if distance >= min_distance:
filtered_trajectory.append(curr_point)
# Convert the filtered list back to a tensor
return torch.stack(filtered_trajectory)
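
if __name__ == "__main__":
    # Illustrative sketch exercising the helpers above on synthetic data; it
    # assumes the cotracker / faiss / kmeans_pytorch imports at the top resolve.
    vt = visual_trace(save_dir="./videos")
    # fake tracks: [batch, time, num_points, xy] plus matching visibility
    pred_tracks = torch.cumsum(torch.randn(1, 16, 32, 2), dim=1) + 256
    pred_visibility = torch.ones(1, 16, 32, dtype=torch.bool)
    lengths = vt.visual_trace_length(pred_tracks, pred_visibility, (512, 512))
    print("average displacement per point:", lengths.shape)  # torch.Size([1, 32])
    trajectory = torch.tensor([[0., 0.], [0.5, 0.5], [3., 3.], [3.2, 3.1], [8., 8.]])
    # drops points that are closer than min_distance to the last kept point
    print(vt.remove_close_points_tensor(trajectory, min_distance=2))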
import io
import os
import cv2
import json
import torch
import numpy as np
from PIL import Image
from IPython import display
from tqdm import tqdm
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor
from matplotlib import cm
import faiss
class visual_tracker():
def __init__(
self,
grid_size=10,
grid_query_frame=0,
backward_tracking=False,
save_dir="./",
ckpt_path=None,
device='cuda'
):
self.grid_size = grid_size
self.grid_query_frame = grid_query_frame
self.backward_tracking = backward_tracking
self.device = device
print("Default device: ", device)
cotracker_checkpoint = ckpt_path
if cotracker_checkpoint is not None:
model = CoTrackerPredictor(checkpoint=cotracker_checkpoint).to(device)
else:
model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(device)
self.model = model
self.visualizer = Visualizer(save_dir=save_dir, pad_value=0, linewidth=1, tracks_leave_trace=-1)
def extract_visual_trace(self, video):
video = video.to(self.device)
pred_tracks, pred_visibility = self.model(
video,
grid_size=self.grid_size,
# grid_query_frame=self.grid_query_frame,
# backward_tracking=self.backward_tracking,
# segm_mask=segm_mask
)
return video, pred_tracks, pred_visibility
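# NOTE (assumption based on CoTracker's public interface, not exercised in this
# commit): `video` is expected as a float tensor of shape [B, T, C, H, W] with
# values in [0, 255]; pred_tracks comes back as [B, T, N, 2] and pred_visibility
# as [B, T, N], matching the shapes assumed by visual_trace_length below.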
# def visual_trace_length(self, pred_tracks, image_size):
# """
# Compute the length of the visual trace
# pred_tracks: e.g., [1, 77, 225, 2]
# """
# distance_accum = 0
# for i in range(1, pred_tracks.size(1)):
# curr_pts = pred_tracks[0, i]
# prev_pts = pred_tracks[0, i - 1]
# delta = curr_pts - prev_pts
# distance = torch.norm(delta, dim=1)
# distance_accum += distance
# distance_accum = distance_accum * 640 / image_size[0]
# return distance_accum / pred_tracks.size(1)
def visual_trace_length(self, pred_tracks, pred_visibility, image_size):
"""
Compute the length of the visual trace
pred_tracks: e.g., [1, 77, 225, 2]
pred_visibility: e.g., [1, 77, 225]
image_size: e.g., [720, 1280]
"""
pred_tracks_normalized = pred_tracks / torch.tensor(image_size).float()[None, None, None, :].to(pred_tracks.device)
pred_visiblity_float = pred_visibility[:, 1:].float().to(pred_tracks.device)
consecutive_displacement = torch.norm(pred_tracks_normalized[:, 1:] - pred_tracks_normalized[:, :-1], dim=3)
# average_displacement = (consecutive_displacement * pred_visiblity_float).sum(1) / (1e-5 + pred_visiblity_float.sum(1))
average_displacement = consecutive_displacement.mean(1)
return average_displacement
@classmethod
def cluster_traces(self, traces, n_clusters=3):
try:
traces_for_clustering = traces[0].transpose(0, 1)
# pred_tracks_4_clustering = pred_tracks_4_clustering - pred_tracks_4_clustering[:, :1]
traces_for_clustering = traces_for_clustering.flatten(1)
kmeans = faiss.Kmeans(
traces_for_clustering.shape[1],
min(n_clusters, traces_for_clustering.shape[0]),
niter=50,
verbose=False,
min_points_per_centroid=1,
max_points_per_centroid=10000000,
)
kmeans.train(traces_for_clustering.cpu().numpy())
distances, cluster_ids_x_np = kmeans.index.search(traces_for_clustering.cpu().numpy(), 1)
cluster_ids_x = torch.from_numpy(cluster_ids_x_np).to(traces_for_clustering.device)
except Exception:
print("kmeans failed")
return None
# sample at least one id per cluster (roughly 20%, clamped to one by the min(1, ...) cap below)
sampled_ids = cluster_ids_x.new_zeros(cluster_ids_x.size(0)).to(traces.device)
for cluster_id in range(min(n_clusters, traces_for_clustering.shape[0])):
cluster_idx = (cluster_ids_x == cluster_id).nonzero().squeeze(1)
num_pts_to_sample = max(1, min(1, int(0.2*cluster_idx.size(0))))
if num_pts_to_sample > 0:
# TODO: random sample is a bit dummy, need a better sampling algo here
sampled_idx = torch.randperm(cluster_idx.size(0))[:num_pts_to_sample]
sampled_ids[cluster_idx[sampled_idx]] = 1
return sampled_ids
def visualize(self, video, pred_tracks, pred_visibility, filename="visual_trace.mp4", mode="rainbow"):
if mode == "rainbow":
self.visualizer.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.visualizer.color_map = cm.get_cmap(mode)
return self.visualizer.visualize(
video,
pred_tracks,
pred_visibility,
query_frame=0 if self.backward_tracking else self.grid_query_frame,
filename=filename,
)
# a list of all the data paths
DATA_PATH:
- "/path/to/llava_v1_5_mix665k.json"
IMAGE_FOLDER:
- "/root/to/llava_v1_5_mix665k/images"
# a list of all the data paths
DATA_PATH:
- "/path/to/magma_820k.json"
IMAGE_FOLDER:
- "/root/to/magma_820k/images"
# a list of all the data paths
DATA_PATH:
- "/path/to/open-x"
IMAGE_FOLDER:
- "siglip-224px+mx-oxe-magic-soup"
LANGUAGE_PATH:
- ""