Commit 0063a668 authored by chenzk

v1.0
import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from tqdm import tqdm
from PIL import Image
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
from data.conversations import Constructor
class LlaVA(Constructor):
    def __init__(self, **kwargs):
        super(LlaVA, self).__init__(**kwargs)
        # load settings from settings.yaml file
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
            self.settings = yaml.safe_load(file)
        self.spatial_quant_size = kwargs.get('spatial_quant_size', 256)  # this is also used for open-x
        self.num_clusters = self.settings['trace_processor']['num_clusters']
        self.root_dir = kwargs.get('dataset_folder', None)
        self.task = kwargs.get('task', 'agent')
        self.use_som_tom = kwargs.get('mm_use_som_tom', True)
        self.tokenizer = kwargs.get('tokenizer', None)
        self.special_tokens = [self.tokenizer.pad_token]

    def __call__(self, **kwargs):
        return super()._construct_conv(**kwargs)

    def filter_items(self, items):
        """
        Filter out items whose conversations contain special tokens.
        """
        num_items = len(items)
        print("Filtering samples containing special tokens")
        kept_items = []
        for item in tqdm(items):
            values = [conv['value'] for conv in item['conversations']]
            # if any special token is present in the conversation, drop the item
            # (collect survivors instead of calling items.remove() while iterating,
            # which would silently skip elements)
            if any(token in value for value in values for token in self.special_tokens):
                print(item)
                continue
            kept_items.append(item)
        print(f"Removed {num_items - len(kept_items)} items containing special tokens")
        return kept_items
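
# Illustrative sanity check (not part of the original file): the filtering rule above
# keeps an item only if none of its conversation values contain a special token.
# The pad token string "<pad>" below is an assumption made purely for illustration.
if __name__ == "__main__":
    special_tokens = ["<pad>"]
    items = [
        {"conversations": [{"from": "human", "value": "Describe the image."}]},
        {"conversations": [{"from": "gpt", "value": "Corrupted sample <pad>"}]},
    ]
    kept = [
        item for item in items
        if not any(token in conv["value"] for conv in item["conversations"] for token in special_tokens)
    ]
    assert len(kept) == 1  # the item containing "<pad>" is dropped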
# tracker settings
tracker:
  backward_tracking: true
  ckpt_path: ./checkpoints/cotracker2.pth
  grid_query_frame: 0
  grid_size: 32
  save_dir: ./

# sft settings
trace_processor:
  num_clusters: 3
trace_planner:
  quant_size: 200
  skip_frames: 16
  step_to_predict: 16  # use same setting as COIN since the videos have 30fps
from .data_utils import Magma as magma
import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
from tqdm import tqdm
from PIL import Image
from data.conversations import Constructor
class Magma(Constructor):
    def __init__(self, **kwargs):
        super(Magma, self).__init__(**kwargs)
        # load settings from settings.yaml file
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
            self.settings = yaml.safe_load(file)
        self.spatial_quant_size = kwargs.get('spatial_quant_size', 256)  # this is also used for open-x
        self.num_clusters = self.settings['trace_processor']['num_clusters']
        self.root_dir = kwargs.get('dataset_folder', None)
        self.task = kwargs.get('task', 'agent')
        self.use_som_tom = kwargs.get('mm_use_som_tom', True)
        self.tokenizer = kwargs.get('tokenizer', None)
        self.special_tokens = [self.tokenizer.pad_token]

    def __call__(self, **kwargs):
        return super()._construct_conv(**kwargs)

    def filter_items(self, items):
        """
        Filter out items whose conversations contain special tokens.
        """
        num_items = len(items)
        print("Filtering samples containing special tokens")
        kept_items = []
        for item in tqdm(items):
            values = [conv['value'] for conv in item['conversations']]
            # if any special token is present in the conversation, drop the item
            # (collect survivors instead of calling items.remove() while iterating,
            # which would silently skip elements)
            if any(token in value for value in values for token in self.special_tokens):
                print(item)
                continue
            kept_items.append(item)
        print(f"Removed {num_items - len(kept_items)} items containing special tokens")
        return kept_items
# tracker settings
tracker:
  backward_tracking: true
  ckpt_path: ./checkpoints/cotracker2.pth
  grid_query_frame: 0
  grid_size: 32
  save_dir: ./

# sft settings
trace_processor:
  num_clusters: 3
trace_planner:
  quant_size: 200
  skip_frames: 16
  step_to_predict: 16  # use same setting as COIN since the videos have 30fps
from .data_utils import OpenXDataItem
from .data_utils import OpenX as openx
"""
action_tokenizer.py
Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
"""
from typing import List, Union
import numpy as np
from transformers import PreTrainedTokenizerBase
class ActionTokenizer:
    def __init__(
        self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1
    ) -> None:
        """
        Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens.

        NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens*
                 appear at the end of the vocabulary!

        :param tokenizer: Base LLM/VLM tokenizer to extend.
        :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy.
        :param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
        :param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
        """
        self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action

        # Create Uniform Bins + Compute Bin Centers
        self.bins = np.linspace(min_action, max_action, self.n_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)`
        #            =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary!
        self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1))

    def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
        """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:])."""
        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
        discretized_action = np.digitize(action, self.bins)

        # Handle single element vs. batch
        if len(discretized_action.shape) == 1:
            return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action))
        else:
            return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist())

    def encode_actions_to_token_ids(self, action: np.ndarray) -> np.ndarray:
        """Encode continuous actions to discrete action token IDs."""
        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
        discretized_action = np.digitize(action, self.bins)
        return self.tokenizer.vocab_size - discretized_action
    def encode_actions_to_discrete_ids(self, action: np.ndarray) -> np.ndarray:
        """Encode continuous actions to discrete bin indices (without mapping them to token IDs)."""
        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
        discretized_action = np.digitize(action, self.bins)
        return discretized_action
    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
        """
        Returns continuous actions for discrete action token IDs.

        NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the
                 digitization returns bin indices between [1, # bins], inclusive, when there are actually only
                 (# bins - 1) bin intervals.

                 Therefore, if the digitization returns the last possible index, we map this to the last bin interval.

        EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns
                    indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There
                    is still one index (i==255) that would cause an out-of-bounds error if used to index into
                    self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of
                    the last bin center. We implement this simply via clipping between [0, 255 - 1].
        """
        discretized_actions = self.tokenizer.vocab_size - action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)

        return self.bin_centers[discretized_actions]
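
    # Worked example of the arithmetic above (illustrative comment, not original code):
    # with bins=256 and a Llama-style vocab_size of 32000, action token id 31745 maps to
    # discretized index 32000 - 31745 = 255; subtracting 1 gives 254, already the last
    # valid index into the 255 bin centers, so clipping leaves it unchanged. Token id
    # 31744 gives 256 -> 255, which clipping maps down to 254, the last bin center.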
    @property
    def vocab_size(self) -> int:
        return self.n_bins
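
# Minimal usage sketch (illustrative; not part of the original file). It assumes a
# Llama-style Hugging Face tokenizer is available locally; the checkpoint name below
# is an assumption, not something this module prescribes.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    base_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    action_tokenizer = ActionTokenizer(base_tokenizer, bins=256, min_action=-1, max_action=1)

    action = np.array([0.1, -0.25, 0.9])  # continuous 3-DoF action
    token_ids = action_tokenizer.encode_actions_to_token_ids(action)
    recovered = action_tokenizer.decode_token_ids_to_actions(token_ids)
    # `recovered` lands on the nearest bin centers, so it matches `action` to within
    # half a bin width, i.e. (max_action - min_action) / (bins - 1) / 2 ≈ 0.004.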
from .datasets import DatasetConfig, DatasetRegistry
from .models import ModelConfig, ModelRegistry
from .vla import VLAConfig, VLARegistry
"""
datasets.py
Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant
and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes:
- Dataset Variant (Identifier) --> e.g., "llava-v15"
- Align Stage Dataset Components (annotations, images)
- Finetune Stage Dataset Components (annotations, images)
- Dataset Root Directory (Path)
"""
from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path
from typing import Tuple
from draccus import ChoiceRegistry
@dataclass
class DatasetConfig(ChoiceRegistry):
    # fmt: off
    dataset_id: str                                 # Unique ID that fully specifies a dataset variant

    # Dataset Components for each Stage in < align | finetune >
    align_stage_components: Tuple[Path, Path]       # Path to annotation file and images directory for `align` stage
    finetune_stage_components: Tuple[Path, Path]    # Path to annotation file and images directory for `finetune` stage

    dataset_root_dir: Path                          # Path to dataset root directory; other paths are relative to root
    # fmt: on
# [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
@dataclass
class LLaVa_V15_Config(DatasetConfig):
    dataset_id: str = "llava-v15"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
# [Multimodal-Only] LLaVa-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training)
@dataclass
class LLaVa_Multimodal_Only_Config(DatasetConfig):
    dataset_id: str = "llava-multimodal"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
# LLaVa-v15 + LVIS-Instruct-4V
@dataclass
class LLaVa_LVIS4V_Config(DatasetConfig):
    dataset_id: str = "llava-lvis4v"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# LLaVa-v15 + LRV-Instruct
@dataclass
class LLaVa_LRV_Config(DatasetConfig):
    dataset_id: str = "llava-lrv"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
@dataclass
class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
    dataset_id: str = "llava-lvis4v-lrv"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
# === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! ===
@unique
class DatasetRegistry(Enum):
    # === LLaVa v1.5 ===
    LLAVA_V15 = LLaVa_V15_Config
    LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config
    LLAVA_LVIS4V = LLaVa_LVIS4V_Config
    LLAVA_LRV = LLaVa_LRV_Config
    LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config

    @property
    def dataset_id(self) -> str:
        return self.value.dataset_id


# Register Datasets in Choice Registry
for dataset_variant in DatasetRegistry:
    DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value)
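
# Usage sketch (illustrative; not part of the original file): each registry entry wraps a
# plain dataclass, so a variant can be instantiated directly and its stage components
# resolved against the dataset root directory.
if __name__ == "__main__":
    dataset_cfg = DatasetRegistry.LLAVA_V15.value()  # instantiate the registered dataclass with its defaults
    annotation_json, image_dir = dataset_cfg.finetune_stage_components
    print(dataset_cfg.dataset_root_dir / annotation_json)
    print(dataset_cfg.dataset_root_dir / image_dir)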