Commit 0063a668 authored by chenzk

v1.0
"""
models.py
Draccus Dataclass Definition for a ModelConfig object, with various registered subclasses for each model family and
variant thereof. A given model variant configures the following attributes:
- Pretrained Visual Representation (e.g., OpenAI CLIP ViT-L/14) + Pretrained LLM Backbone (e.g., LLaMa-2 7B)
- VLM Configuration + Parameters (e.g., MLP Projector, Image Preprocessing, etc.)
- [Optional] Stage 1 (`align`) Optimization Hyperparameters
- Stage 2 (`finetune`) Optimization Hyperparameters
"""
from dataclasses import dataclass
from enum import Enum, unique
from typing import Optional
from draccus import ChoiceRegistry
@dataclass
class ModelConfig(ChoiceRegistry):
# fmt: off
model_id: str # Unique Model ID that fully specifies a given variant
arch_specifier: str # Architecture specifier string (e.g., "gelu-mlp")
# Pretrained Backbones
vision_backbone_id: str # Pretrained Visual Featurizer (from TIMM) to load
llm_backbone_id: str # Pretrained LLM (from HF Transformers) to load
# Backbone Parameters
image_resize_strategy: str # Resizing strategy in < crop | letterbox | corner-pad >
llm_max_length: int # Maximum context length for LLM (can be shorter than the backbone's native max!)
# === Multi-Stage Optimization Hyperparameters ===
# By default, we assume an AdamW optimizer with FSDP (Gradient Sharding or Full Sharding depending on stage)
# Align Stage Optimization Parameters
align_epochs: int # Epochs to Run (in case `max_steps` is not specified)
align_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs)
align_global_batch_size: int # Global Batch Size (divided across processes)
align_per_device_batch_size: int # Per-Device Batch Size (per-process)
# => # of accumulation steps is auto-computed
align_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay)
align_weight_decay: float # Weight Decay for AdamW Optimizer
align_max_grad_norm: float # Max Grad Norm (for global gradient clipping)
align_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay")
align_warmup_ratio: float # Fraction of total steps to warmup
align_train_strategy: str # Align Train Strategy (default: "fsdp-shard-grad-op")
# Finetune Stage Optimization Parameters
finetune_epochs: int # Epochs to Run (in case `max_steps` is not specified)
finetune_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs)
finetune_global_batch_size: int # Global Batch Size (divided across processes)
finetune_per_device_batch_size: int # Per-Device Batch Size (per-process)
# => # of accumulation steps is auto-computed
finetune_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay)
finetune_weight_decay: float # Weight Decay for AdamW Optimizer
finetune_max_grad_norm: float # Max Grad Norm (for global gradient clipping)
finetune_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay")
finetune_warmup_ratio: float # Fraction of total steps to warmup
finetune_train_strategy: str # Finetune Train Strategy (default: "fsdp-full-shard")
# Enable Gradient/Activation Checkpointing (for the LLM Backbone)
enable_gradient_checkpointing: bool = True
# Enable Traditional Mixed Precision Training via Torch Native AMP (`autocast`)
enable_mixed_precision_training: bool = True # Whether to enable mixed precision training
reduce_in_full_precision: bool = False # Whether to run gradient reduction in FP32
# fmt: on
# === LLaVa v1.5 Reproduction - Fully Specified Configurations ===
@dataclass
class LLaVa_v15_Reproduction_7B(ModelConfig):
model_id: str = "reproduction-llava-v15+7b"
arch_specifier: str = "gelu-mlp"
vision_backbone_id: str = "clip-vit-l-336px"
llm_backbone_id: str = "vicuna-v15-7b"
image_resize_strategy: str = "letterbox"
llm_max_length: int = 2048
# Align Stage Optimization Parameters
align_epochs: int = 1
align_max_steps: Optional[int] = None
align_global_batch_size: int = 256
align_per_device_batch_size: int = 16
align_learning_rate: float = 1e-3
align_weight_decay: float = 0.0
align_max_grad_norm: float = 1.0
align_lr_scheduler_type: str = "linear-warmup+cosine-decay"
align_warmup_ratio: float = 0.03
align_train_strategy: str = "fsdp-shard-grad-op"
# Finetune Stage Optimization Parameters
finetune_epochs: int = 1
finetune_max_steps: Optional[int] = None
finetune_global_batch_size: int = 128
finetune_per_device_batch_size: int = 16
finetune_learning_rate: float = 2e-5
finetune_weight_decay: float = 0.1
finetune_max_grad_norm: float = 1.0
finetune_lr_scheduler_type: str = "linear-warmup+cosine-decay"
finetune_warmup_ratio: float = 0.03
finetune_train_strategy: str = "fsdp-full-shard"
@dataclass
class LLaVa_v15_Reproduction_13B(LLaVa_v15_Reproduction_7B):
model_id: str = "reproduction-llava-v15+13b"
llm_backbone_id: str = "vicuna-v15-13b"
# === Section 4.1 :: Optimization Procedure ===
# Section 4.1A :: 🚀 --> Necessity of Multi-Stage Training
@dataclass
class Exp_7B_One_Stage(LLaVa_v15_Reproduction_7B):
model_id: str = "one-stage+7b"
arch_specifier: str = "no-align+gelu-mlp"
@dataclass
class Exp_13B_One_Stage(LLaVa_v15_Reproduction_13B):
model_id: str = "one-stage+13b"
arch_specifier: str = "no-align+gelu-mlp"
# Section 4.1B :: 🛠️ --> Full Finetuning through Visual Backbones
# =>> Note :: Run with `--stage full-finetune`
@dataclass
class Exp_7B_Full_Finetune_Multi_Stage(LLaVa_v15_Reproduction_7B):
model_id: str = "full-ft-multi-stage+7b"
@dataclass
class Exp_7B_Full_Finetune_One_Stage(Exp_7B_One_Stage):
model_id: str = "full-ft-one-stage+7b"
# === Section 4.2 :: Image Processing and Visual Representations ===
# Section 4.2A :: 📸 --> Choosing a Pretrained Representation
@dataclass
class Exp_7B_IN1K_ViT_L_p16_224px(Exp_7B_One_Stage):
model_id: str = "in1k-224px+7b"
vision_backbone_id: str = "in1k-vit-l"
@dataclass
class Exp_7B_DINOv2_ViT_L_p14_224px(Exp_7B_One_Stage):
model_id: str = "dinov2-224px+7b"
vision_backbone_id: str = "dinov2-vit-l"
@dataclass
class Exp_7B_CLIP_ViT_L_p14_224px(Exp_7B_One_Stage):
model_id: str = "clip-224px+7b"
vision_backbone_id: str = "clip-vit-l"
@dataclass
class Exp_7B_SigLIP_ViT_SO_p14_224px(Exp_7B_One_Stage):
model_id: str = "siglip-224px+7b"
vision_backbone_id: str = "siglip-vit-so400m"
# Section 4.2B :: 📐 --> Choosing an Image Preprocessing Strategy
@dataclass
class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop(Exp_7B_One_Stage):
model_id: str = "clip-336px-resize-crop+7b"
image_resize_strategy: str = "resize-crop"
@dataclass
class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage):
model_id: str = "clip-336px-resize-naive+7b"
image_resize_strategy: str = "resize-naive"
@dataclass
class Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox(Exp_7B_One_Stage):
model_id: str = "siglip-384px-letterbox+7b"
vision_backbone_id: str = "siglip-vit-so400m-384px"
image_resize_strategy: str = "letterbox"
@dataclass
class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop(Exp_7B_One_Stage):
model_id: str = "siglip-384px-resize-crop+7b"
vision_backbone_id: str = "siglip-vit-so400m-384px"
image_resize_strategy: str = "resize-crop"
@dataclass
class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive(Exp_7B_One_Stage):
model_id: str = "siglip-384px-resize-naive+7b"
vision_backbone_id: str = "siglip-vit-so400m-384px"
image_resize_strategy: str = "resize-naive"
# Section 4.2D :: 🥞 --> Stacking/Ensembling Visual Representations
@dataclass
class Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox(Exp_7B_One_Stage):
model_id: str = "dinoclip-336px-letterbox+7b"
vision_backbone_id: str = "dinoclip-vit-l-336px"
image_resize_strategy: str = "letterbox"
arch_specifier: str = "no-align+fused-gelu-mlp"
@dataclass
class Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage):
model_id: str = "dinoclip-336px-resize-naive+7b"
vision_backbone_id: str = "dinoclip-vit-l-336px"
image_resize_strategy: str = "resize-naive"
arch_specifier: str = "no-align+fused-gelu-mlp"
@dataclass
class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox(Exp_7B_One_Stage):
model_id: str = "dinosiglip-384px-letterbox+7b"
vision_backbone_id: str = "dinosiglip-vit-so-384px"
image_resize_strategy: str = "letterbox"
arch_specifier: str = "no-align+fused-gelu-mlp"
@dataclass
class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive(Exp_7B_One_Stage):
model_id: str = "dinosiglip-384px-resize-naive+7b"
vision_backbone_id: str = "dinosiglip-vit-so-384px"
image_resize_strategy: str = "resize-naive"
arch_specifier: str = "no-align+fused-gelu-mlp"
# === Section 4.3 :: Language Models ===
# Section 4.3A :: 📝 --> Base vs. Instruct-Tuned (Chat) LLMs
@dataclass
class Exp_7B_Llama2(Exp_7B_One_Stage):
model_id: str = "llama2+7b"
llm_backbone_id: str = "llama2-7b-pure"
@dataclass
class Exp_13B_Llama2(Exp_13B_One_Stage):
model_id: str = "llama2+13b"
llm_backbone_id: str = "llama2-13b-pure"
# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~
@dataclass
class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
model_id: str = "llama2-chat+7b"
llm_backbone_id: str = "llama2-7b-chat"
@dataclass
class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage):
model_id: str = "llama2-chat+13b"
llm_backbone_id: str = "llama2-13b-chat"
@dataclass
class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage):
model_id: str = "mistral-v0.1+7b"
llm_backbone_id: str = "mistral-v0.1-7b-pure"
@dataclass
class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
model_id: str = "mistral-instruct-v0.1+7b"
llm_backbone_id: str = "mistral-v0.1-7b-instruct"
@dataclass
class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage):
model_id: str = "phi-2+3b"
llm_backbone_id: str = "phi-2-3b"
# Section 4.3B :: ✌️ --> Co-training on Language-only Data
# =>> Note :: Run with `--dataset.type "llava-multimodal"` (multimodal data only / no co-training)
@dataclass
class Exp_7B_Vicuna_No_Cotraining(Exp_7B_One_Stage):
model_id: str = "vicuna-no-cotraining+7b"
@dataclass
class Exp_7B_Llama2_No_Cotraining(Exp_7B_One_Stage):
model_id: str = "llama2-no-cotraining+7b"
llm_backbone_id: str = "llama2-7b-pure"
# === Section 4.4 :: Scaling Properties - Train Time & Data ===
# Section 4.4A :: ⏰ --> Scaling Train Time
@dataclass
class Exp_7B_1p25_Epochs(Exp_7B_One_Stage):
model_id: str = "train-1.25-epochs+7b"
finetune_max_steps: int = 6500
@dataclass
class Exp_7B_1p5_Epochs(Exp_7B_One_Stage):
model_id: str = "train-1.5-epochs+7b"
finetune_max_steps: int = 7800
@dataclass
class Exp_7B_2_Epochs(Exp_7B_One_Stage):
model_id: str = "train-2-epochs+7b"
finetune_epochs: int = 2
@dataclass
class Exp_7B_3_Epochs(Exp_7B_One_Stage):
model_id: str = "train-3-epochs+7b"
finetune_epochs: int = 3
# Section 4.4B :: 📚 --> Scaling Data
# =>> Note :: Run with `--dataset.type "llava-lvis4v"`
@dataclass
class Exp_7B_LLaVa_LVIS4V(Exp_7B_One_Stage):
model_id: str = "llava-lvis4v+7b"
# =>> Note :: Run with `--dataset.type "llava-lrv"`
@dataclass
class Exp_7B_LLaVa_LRV(Exp_7B_One_Stage):
model_id: str = "llava-lrv+7b"
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Exp_7B_LLaVa_LVIS4V_LRV(Exp_7B_One_Stage):
model_id: str = "llava-lvis4v-lrv+7b"
# === Section 5 :: Prisms ===
# Prism-CLIP
@dataclass
class Prism_7B_CLIP_Controlled(Exp_7B_One_Stage):
model_id: str = "prism-clip-controlled+7b"
vision_backbone_id: str = "clip-vit-l-336px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
@dataclass
class Prism_13B_CLIP_Controlled(Exp_13B_One_Stage):
model_id: str = "prism-clip-controlled+13b"
vision_backbone_id: str = "clip-vit-l-336px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-13b-pure"
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_7B_CLIP(Exp_7B_One_Stage):
model_id: str = "prism-clip+7b"
vision_backbone_id: str = "clip-vit-l-336px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
finetune_epochs: int = 2
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_13B_CLIP(Exp_13B_One_Stage):
model_id: str = "prism-clip+13b"
vision_backbone_id: str = "clip-vit-l-336px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-13b-pure"
finetune_epochs: int = 2
# Prism-SigLIP
@dataclass
class Prism_7B_SigLIP_Controlled(Exp_7B_One_Stage):
model_id: str = "prism-siglip-controlled+7b"
vision_backbone_id: str = "siglip-vit-so400m-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
@dataclass
class Prism_13B_SigLIP_Controlled(Exp_13B_One_Stage):
model_id: str = "prism-siglip-controlled+13b"
vision_backbone_id: str = "siglip-vit-so400m-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-13b-pure"
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_7B_SigLIP(Exp_7B_One_Stage):
model_id: str = "prism-siglip+7b"
vision_backbone_id: str = "siglip-vit-so400m-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
finetune_epochs: int = 2
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_13B_SigLIP(Exp_13B_One_Stage):
model_id: str = "prism-siglip+13b"
vision_backbone_id: str = "clip-vit-l-336px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-13b-pure"
finetune_epochs: int = 2
# Prism-DINOSigLIP
@dataclass
class Prism_7B_DINOSigLIP_Controlled(Exp_7B_One_Stage):
model_id: str = "prism-dinosiglip-controlled+7b"
vision_backbone_id: str = "dinosiglip-vit-so-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
arch_specifier: str = "no-align+fused-gelu-mlp"
@dataclass
class Prism_13B_DINOSigLIP_Controlled(Exp_13B_One_Stage):
model_id: str = "prism-dinosiglip-controlled+13b"
vision_backbone_id: str = "dinosiglip-vit-so-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-13b-pure"
arch_specifier: str = "no-align+fused-gelu-mlp"
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_7B_DINOSigLIP(Exp_7B_One_Stage):
model_id: str = "prism-dinosiglip+7b"
vision_backbone_id: str = "dinosiglip-vit-so-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
arch_specifier: str = "no-align+fused-gelu-mlp"
finetune_epochs: int = 2
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_13B_DINOSigLIP(Exp_13B_One_Stage):
model_id: str = "prism-dinosiglip+13b"
vision_backbone_id: str = "dinosiglip-vit-so-384px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-13b-pure"
arch_specifier: str = "no-align+fused-gelu-mlp"
finetune_epochs: int = 2
# [Inference-Optimized] 224px Prisms
@dataclass
class Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive(Exp_7B_One_Stage):
model_id: str = "dinosiglip-224px-resize-naive+7b"
vision_backbone_id: str = "dinosiglip-vit-so-224px"
image_resize_strategy: str = "resize-naive"
arch_specifier: str = "no-align+fused-gelu-mlp"
@dataclass
class Prism_7B_DINOSigLIP_224px_Controlled(Exp_7B_One_Stage):
model_id: str = "prism-dinosiglip-224px-controlled+7b"
vision_backbone_id: str = "dinosiglip-vit-so-224px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
arch_specifier: str = "no-align+fused-gelu-mlp"
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
@dataclass
class Prism_7B_DINOSigLIP_224px(Exp_7B_One_Stage):
model_id: str = "prism-dinosiglip-224px+7b"
vision_backbone_id: str = "dinosiglip-vit-so-224px"
image_resize_strategy: str = "resize-naive"
llm_backbone_id: str = "llama2-7b-pure"
arch_specifier: str = "no-align+fused-gelu-mlp"
finetune_epochs: int = 2
# === Define a Model Registry Enum for Reference & Validation ===
@unique
class ModelRegistry(Enum):
# === LLaVa v1.5 Base Reproductions ===
REPRODUCTION_7B = LLaVa_v15_Reproduction_7B
REPRODUCTION_13B = LLaVa_v15_Reproduction_13B
# === Section 4.1 :: Optimization Procedure ===
EXP_ONE_STAGE_7B = Exp_7B_One_Stage
EXP_ONE_STAGE_13B = Exp_13B_One_Stage
EXP_FULL_FT_MULTI_STAGE = Exp_7B_Full_Finetune_Multi_Stage
EXP_FULL_FT_ONE_STAGE = Exp_7B_Full_Finetune_One_Stage
# === Section 4.2 :: Image Processing and Visual Representations ===
EXP_IN1K_224PX = Exp_7B_IN1K_ViT_L_p16_224px
EXP_DINOV2_224PX = Exp_7B_DINOv2_ViT_L_p14_224px
EXP_CLIP_224PX = Exp_7B_CLIP_ViT_L_p14_224px
EXP_SIGLIP_224PX = Exp_7B_SigLIP_ViT_SO_p14_224px
EXP_CLIP_336PX_RESIZE_CROP = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop
EXP_CLIP_336PX_RESIZE_NAIVE = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive
EXP_SIGLIP_384PX_LETTERBOX = Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox
EXP_SIGLIP_384PX_RESIZE_CROP = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop
EXP_SIGLIP_384PX_RESIZE_NAIVE = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive
EXP_DINOCLIP_336PX_LETTERBOX = Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox
EXP_DINOCLIP_336PX_RESIZE_NAIVE = Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive
EXP_DINOSIGLIP_384PX_LETTERBOX = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox
EXP_DINOSIGLIP_384PX_RESIZE_NAIVE = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive
# === Section 4.3 :: Language Models ===
EXP_LLAMA2_7B = Exp_7B_Llama2
EXP_LLAMA2_13B = Exp_13B_Llama2
# ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat
EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2
# Cotraining w/ Unimodal Data
EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
EXP_LLAMA2_NO_COTRAINING_7B = Exp_7B_Llama2_No_Cotraining
# === Section 4.4 :: Scaling Properties - Train Time & Data ===
EXP_1P25_EPOCHS = Exp_7B_1p25_Epochs
EXP_1P5_EPOCHS = Exp_7B_1p5_Epochs
EXP_2_EPOCHS = Exp_7B_2_Epochs
EXP_3_EPOCHS = Exp_7B_3_Epochs
EXP_LLAVA_LVIS4V = Exp_7B_LLaVa_LVIS4V
EXP_LLAVA_LRV = Exp_7B_LLaVa_LRV
EXP_LLAVA_LVIS4V_LRV = Exp_7B_LLaVa_LVIS4V_LRV
# === Section 5 :: Prisms ===
PRISM_CLIP_CONTROLLED_7B = Prism_7B_CLIP_Controlled
PRISM_CLIP_CONTROLLED_13B = Prism_13B_CLIP_Controlled
PRISM_CLIP_7B = Prism_7B_CLIP
PRISM_CLIP_13B = Prism_13B_CLIP
PRISM_SIGLIP_CONTROLLED_7B = Prism_7B_SigLIP_Controlled
PRISM_SIGLIP_CONTROLLED_13B = Prism_13B_SigLIP_Controlled
PRISM_SIGLIP_7B = Prism_7B_SigLIP
PRISM_SIGLIP_13B = Prism_13B_SigLIP
PRISM_DINOSIGLIP_CONTROLLED_7B = Prism_7B_DINOSigLIP_Controlled
PRISM_DINOSIGLIP_CONTROLLED_13B = Prism_13B_DINOSigLIP_Controlled
PRISM_DINOSIGLIP_7B = Prism_7B_DINOSigLIP
PRISM_DINOSIGLIP_13B = Prism_13B_DINOSigLIP
# === Inference Optimized :: 224px Prisms ===
OPT_DINOSIGLIP_224PX_RESIZE_NAIVE = Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive
PRISM_DINOSIGLIP_224PX_CONTROLLED_7B = Prism_7B_DINOSigLIP_224px_Controlled
PRISM_DINOSIGLIP_224PX_7B = Prism_7B_DINOSigLIP_224px
@property
def model_id(self) -> str:
return self.value.model_id
# Register Models in Choice Registry
for model_variant in ModelRegistry:
ModelConfig.register_subclass(model_variant.model_id, model_variant.value)
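# Minimal usage sketch: with the loop above having registered every `ModelRegistry` entry, a fully
# specified config can be looked up by its `model_id` string and instantiated with its dataclass
# defaults (or overridden via draccus CLI parsing). The helper below is illustrative only.
def _example_lookup_model_config(model_id: str = "prism-dinosiglip+7b") -> ModelConfig:
    cfg_cls = ModelConfig.get_choice_class(model_id)  # e.g., resolves to Prism_7B_DINOSigLIP
    cfg = cfg_cls()                                   # dataclass defaults fully specify the variant
    assert cfg.model_id == model_id
    return cfg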
"""
vla.py
Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and
model configuration thereof. A given VLA model (`policy`) configures the following attributes:
- Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.)
- Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`)
- VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
- Training / Optimization Hyperparameters
"""
from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path
from typing import Optional, Union
from draccus import ChoiceRegistry
@dataclass
class VLAConfig(ChoiceRegistry):
# fmt: off
vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant
base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`)
freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining)
freeze_llm_backbone: bool # Freeze LLM Backbone parameters
unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen)
# Data Mixture Parameters
data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`)
shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE)
# Optimization Parameters
epochs: int # Epochs to Run (in case `max_steps` is not specified)
max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`)
expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware
global_batch_size: int # Global Batch Size (divided across processes / world size)
per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU)
# =>> # of accumulation steps is auto-computed
learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
weight_decay: float # Weight Decay for AdamW Optimizer
max_grad_norm: float # Max Grad Norm (for global gradient clipping)
lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers)
train_strategy: str # Train Strategy (default "fsdp-full-shard")
# Enable Gradient/Activation Checkpointing (for the LLM Backbone)
enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training
# Mixed Precision Training via Torch Native AMP (`autocast`)
enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision
reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision
# fmt: on
# === OpenVLA Training Configurations ===
# = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
@dataclass
class Exp_SigLIP_224px_Bridge(VLAConfig):
vla_id: str = "siglip-224px+mx-bridge"
base_vlm: Union[str, Path] = "siglip-224px+7b"
freeze_vision_backbone: bool = False
freeze_llm_backbone: bool = False
unfreeze_last_llm_layer: bool = False
# Data Mixture Parameters
data_mix: str = "bridge"
shuffle_buffer_size: int = 256_000
# Optimization Parameters
epochs: int = 1000
max_steps: Optional[int] = None
expected_world_size: int = 8
global_batch_size: int = 256
per_device_batch_size: int = 32
learning_rate: float = 2e-5
weight_decay: float = 0.0
max_grad_norm: float = 1.0
lr_scheduler_type: str = "constant"
warmup_ratio: float = 0.0
train_strategy: str = "fsdp-full-shard"
# = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
@dataclass
class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px-icy+mx-bridge"
base_vlm: Union[str, Path] = "siglip-224px+7b"
freeze_vision_backbone: bool = True
# = [1 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
@dataclass
class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
vla_id: str = "prism-dinosiglip-224px+mx-bridge"
base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
data_mix: str = "bridge"
expected_world_size: int = 1
global_batch_size: int = 32
per_device_batch_size: int = 32
# = [64 GPU] SigLIP 224px + OXE Magic Soup =
@dataclass
class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px+mx-oxe-magic-soup"
base_vlm: Union[str, Path] = "siglip-224px+7b"
data_mix: str = "oxe_magic_soup"
expected_world_size: int = 64
global_batch_size: int = 2048
per_device_batch_size: int = 32
# = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
@dataclass
class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
# Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling!
# data_mix: str = "oxe_magic_soup_plus"
data_mix: str = "oxe_magic_soup_plus_minus"
expected_world_size: int = 64
global_batch_size: int = 2048
per_device_batch_size: int = 32
# === OpenVLA Fine-tuning Configurations ===
# = [8 GPU] SigLIP 224px + T-DROID =
@dataclass
class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
base_vlm: Union[str, Path] = "siglip-224px+7b"
data_mix: str = "tdroid_carrot_in_bowl"
@dataclass
class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
base_vlm: Union[str, Path] = "siglip-224px+7b"
data_mix: str = "tdroid_pour_corn_in_pot"
# = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
@dataclass
class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
base_vlm: Union[str, Path] = "siglip-224px+7b"
freeze_vision_backbone: bool = True
freeze_llm_backbone: bool = False
data_mix: str = "tdroid_carrot_in_bowl"
@dataclass
class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
base_vlm: Union[str, Path] = "siglip-224px+7b"
freeze_vision_backbone: bool = True
freeze_llm_backbone: bool = True
unfreeze_last_llm_layer: bool = True
data_mix: str = "tdroid_carrot_in_bowl"
@dataclass
class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
base_vlm: Union[str, Path] = "siglip-224px+7b"
freeze_vision_backbone: bool = False
freeze_llm_backbone: bool = True
unfreeze_last_llm_layer: bool = True
data_mix: str = "tdroid_carrot_in_bowl"
# === [8 GPU] SigLIP 224px + FrankaWipe ===
@dataclass
class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
vla_id: str = "siglip-224px+mx-droid_wipe"
base_vlm: Union[str, Path] = "siglip-224px+7b"
data_mix: str = "droid_wipe"
# === Define a VLA Registry Enum for Reference & Validation ===
@unique
class VLARegistry(Enum):
# Sanity Check Configurations =>> BridgeV2
SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge
# SigLIP Frozen Backbone Experiment
FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge
# [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup
# [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus
# === TDROID Fine-tuning Configs ===
SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot
SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl
# === DROID Fine-tuning Configs ===
SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe
@property
def vla_id(self) -> str:
return self.value.vla_id
# Register VLAs in Choice Registry
for vla_variant in VLARegistry:
VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value)
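# Minimal sketch of the "# of accumulation steps is auto-computed" note in `VLAConfig` above,
# assuming the usual convention global = per_device * world_size * grad_accumulation. For
# `Exp_SigLIP_224px_Bridge`, 256 // (32 * 8) = 1 (no accumulation); the helper is illustrative only.
def _example_grad_accumulation_steps(cfg: VLAConfig) -> int:
    denom = cfg.per_device_batch_size * cfg.expected_world_size
    assert cfg.global_batch_size % denom == 0, "global batch size must divide evenly across devices"
    return cfg.global_batch_size // denom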
import torch
import torchvision
import re
import cv2
import numpy as np
import os
import yaml
import logging
from PIL import Image
import torch.distributed as dist
from data.utils.visual_trace import visual_trace
from data.utils.som_tom import som_prompting, tom_prompting
from data.conversations import Constructor
from .conf import VLAConfig, VLARegistry
from dataclasses import dataclass, field
from magma.processing_magma import MagmaProcessor
from .materialize import get_vla_dataset_and_collator
from .datasets.rlds.utils.data_utils import save_dataset_statistics
from data.utils.visual_tracker import visual_tracker
logger = logging.getLogger(__name__)
"""
data_utils.py
General utilities and classes for facilitating data loading and collation.
"""
from dataclasses import dataclass
from typing import Callable, Dict, Sequence, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
from torch import distributed as dist
# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100
@dataclass
class OpenXDataItem:
def __call__(self, data_root_dir, data_soup, processor=None, conversation_lib=None, image_aug=False, local_run=False, future_action_window_size=1):
# VLAConfig (`prismatic/conf/vla.py`); override with --vla.type `VLARegistry.<VLA>.vla_id`
self.openx_data_cfg = VLAConfig.get_choice_class(data_soup)
default_image_resolution = processor.image_processor.base_img_size
logger.info(f"Creating VLA Open-X Dataset with Mixture `{self.openx_data_cfg.data_mix}`")
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
self.settings = yaml.safe_load(file)
# get local rank for distributed training
rank = dist.get_rank() if dist.is_initialized() else 0
rank = rank % torch.cuda.device_count()
openx_dataset, action_tokenizer, collator = get_vla_dataset_and_collator(
data_root_dir,
self.openx_data_cfg.data_mix,
shuffle_buffer_size=1 if (local_run or future_action_window_size>1) else self.openx_data_cfg.shuffle_buffer_size,
image_transform=processor.image_processor,
visual_tracker=visual_tracker(**self.settings.get('tracker', {}), device=f"cuda:{rank}"),
dataset_settings=self.settings,
tokenizer=processor.tokenizer,
default_image_resolution=(3, default_image_resolution, default_image_resolution),
image_aug=image_aug,
future_action_window_size=future_action_window_size,
prompt_builder_fn=conversation_lib, # vlm.llm_backbone.prompt_builder_fn,
local_run=local_run,
)
# Save dataset statistics for de-normalization at inference time
# if overwatch.is_rank_zero():
# save_dataset_statistics(openx_dataset.dataset_statistics, run_dir)
return openx_dataset
class OpenX(Constructor):
def __init__(self, **kwargs):
super(OpenX, self).__init__(**kwargs)
# load settings from settings.yaml file
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.yaml'), 'r') as file:
self.settings = yaml.safe_load(file)
self.spatial_quant_size = kwargs.get('spatial_quant_size', 256) # this is also used for open-x
self.num_clusters = self.settings['trace_processor']['num_clusters']
self.root_dir = kwargs.get('dataset_folder', None)
def __call__(self, **kwargs):
return super()._construct_conv(**kwargs)
def filter_items(self, items):
"""
Filter invalid items
"""
return items
from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset
"""
datasets.py
Lightweight PyTorch Dataset definitions for wrapping the RLDS TFDS pipeline; just defines the transform from the
RLDS default format to the OpenVLA format, plus an IterableDataset shim.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple, Type
import collections
import os
import numpy as np
import cv2
import torch
from PIL import Image
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import Dataset, IterableDataset
from transformers import PreTrainedTokenizerBase
from data.utils.som_tom import som_prompting, tom_prompting
# from prismatic.models.backbones.llm.prompting import PromptBuilder
# from prismatic.models.backbones.vision import ImageTransform
from ..action_tokenizer import ActionTokenizer
from .rlds import make_interleaved_dataset, make_single_dataset
from .rlds.oxe import OXE_NAMED_MIXTURES, get_oxe_dataset_kwargs_and_weights
from .rlds.utils.data_utils import NormalizationType
# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100
from typing import Callable, Dict, Sequence, Tuple
def tree_map(fn: Callable, tree: dict) -> dict:
"""Maps a function over a nested dictionary."""
return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()}
def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict:
"""Maps a function over a nested dictionary."""
return {
k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items()
}
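# Tiny usage sketch for the two helpers above (illustrative only): `tree_map` applies `fn` to every
# leaf value, while `tree_map_with_key` also passes the tuple of keys leading to that leaf.
def _example_tree_map_usage() -> None:
    assert tree_map(lambda v: v * 2, {"a": 1, "b": {"c": 3}}) == {"a": 2, "b": {"c": 6}}
    assert tree_map_with_key(lambda path, v: v + len(path), {"a": {"b": 1}}) == {"a": {"b": 3}}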
@dataclass
class RLDSBatchTransform:
action_tokenizer: ActionTokenizer
base_tokenizer: PreTrainedTokenizerBase
image_transform: None # ImageTransform
prompt_builder_fn: None # Type[PromptBuilder]
visual_tracker: None
dataset_settings: None
data_root_dir: str = "/mnt/vlpdatasets"
predict_stop_token: bool = True
trace_folder: str = "open-x-traces-v2"
image_folder: str = "open-x-images-v2"
local_run: bool = False
def __call__(self, rlds_batch: Dict[str, Any]) -> Dict[str, Any]:
"""Converts a RLDS batch to the format expected by the OpenVLA collator/models."""
dataset_name, action = rlds_batch["dataset_name"], rlds_batch["action"][0]
img = Image.fromarray(rlds_batch["observation"]["image_primary"][0])
imgs_future = [Image.fromarray(img) for img in rlds_batch["observation_future"]["image_primary"]]
lang = rlds_batch["task"]["language_instruction"].decode().lower()
traj_index = rlds_batch['_traj_index']
frame_index = rlds_batch['_frame_index']
action_token_ids = self.action_tokenizer.encode_actions_to_token_ids(action)
# Construct Chat-based Prompt =>> Input is default query + language instruction, output are the action tokens
convs = [
{"role": "system", "content": "You are agent that can see, talk and act."},
{"role": "user", "content": f"<image>\nWhat action should the robot take to {lang}?"},
{"role": "assistant", "content": "<action>"},
]
prompt = self.base_tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=False)
# Tokenize (w/ `base_tokenizer`)
input_ids = self.base_tokenizer(prompt, add_special_tokens=True).input_ids
action_token_len = len(action_token_ids)
action_placeholder_token_id = self.base_tokenizer.convert_tokens_to_ids("<action>")
# # replace the action_placeholder_token_id with action_token_ids in input_ids
input_ids = list(input_ids)
input_ids_filled = []
for i, token_id in enumerate(input_ids):
if token_id == action_placeholder_token_id:
input_ids_filled.extend(action_token_ids.tolist())
else:
input_ids_filled.append(token_id)
# Tensorize =>> Run Image Transform to get `pixel_values` =>> Return
# =>> IMPORTANT :: IF WE'RE USING HF LLM.forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
input_ids, labels = torch.tensor(input_ids_filled), torch.tensor(input_ids_filled)
pixel_values = transforms.Compose([transforms.ToTensor()])(img)
image_pt = self.image_transform(img, return_tensors='pt')
images = collections.defaultdict(list)
for key, val in image_pt.items():
images[key].append(val)
pixel_values_future = torch.stack([transforms.Compose([transforms.ToTensor()])(item) for item in imgs_future], dim=0)
# extract visual traces
# trace_folder = self.trace_folder
# if not os.path.exists(trace_folder):
# trace_folder = "./open-x-traces-v2"
# trace_file = f"{dataset_name}/{traj_index}/{frame_index}.pth"
# trace_path = os.path.join(trace_folder, trace_file)
# if not os.path.exists(trace_path):
# pixel_values_seq = torch.cat([pixel_values.unsqueeze(0), pixel_values_future], dim=0).unsqueeze(0)
# out = self.visual_tracker.extract_visual_trace(pixel_values_seq*255)
# # self.visual_tracker.visualize(*out)
# # save the visual trace to disk
# trace_info = {
# 'dataset_name': dataset_name,
# 'traj_index': traj_index,
# 'frame_index': frame_index,
# 'lang': lang,
# 'action': action,
# 'trace_info': out[1:]
# }
# os.makedirs(os.path.dirname(trace_path), exist_ok=True)
# torch.save(trace_info, trace_path)
# save image
# image_folder = self.image_folder
# if not os.path.exists(image_folder):
# image_folder = "./open-x-images-v2"
# image_file = f"{dataset_name}/{traj_index}/{frame_index}.jpg"
# image_path = os.path.join(image_folder, image_file)
# if not os.path.exists(image_path):
# os.makedirs(os.path.dirname(image_path), exist_ok=True)
# img.save(image_path)
# [CRITICAL] We do not want to take the loss for anything but the predicted action tokens!
# NOTE: we add 2 to the length of the action to account for the \n\n and <|eot_id|> tokens!
labels[: -(action_token_len + 2)] = IGNORE_INDEX
if not self.predict_stop_token:
labels[-1] = IGNORE_INDEX
return dict(pixel_values=images['pixel_values'], image_sizes=images['image_sizes'], pixel_values_future=pixel_values_future, input_ids=input_ids, labels=labels, dataset_name=dataset_name)
# return dict(pixel_values=pixel_values, pixel_values_future=pixel_values_future, action=action, conversation=conversation, dataset_name=dataset_name)
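# Label-masking sketch for the `__call__` above (illustrative only): given tokenized ids that end
# with the action tokens plus the trailing "\n\n" and "<|eot_id|>" tokens, every earlier position is
# set to IGNORE_INDEX so the loss covers only the final `action_token_len + 2` positions.
def _example_mask_action_labels(input_ids: torch.Tensor, action_token_len: int) -> torch.Tensor:
    labels = input_ids.clone()
    labels[: -(action_token_len + 2)] = IGNORE_INDEX
    return labels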
class RLDSDataset(IterableDataset):
def __init__(
self,
data_root_dir: Path,
data_mix: str,
batch_transform: RLDSBatchTransform,
resize_resolution: Tuple[int, int],
shuffle_buffer_size: int = 256_000,
train: bool = True,
image_aug: bool = False,
future_action_window_size: int = 0,
) -> None:
"""Lightweight wrapper around RLDS TFDS Pipeline for use with PyTorch/OpenVLA Data Loaders."""
self.data_root_dir, self.data_mix, self.batch_transform = data_root_dir, data_mix, batch_transform
# Configure RLDS Dataset(s)
if self.data_mix in OXE_NAMED_MIXTURES:
mixture_spec = OXE_NAMED_MIXTURES[self.data_mix]
else:
# Assume that passed "mixture" name is actually a single dataset -- create single-dataset "mix"
mixture_spec = [(self.data_mix, 1.0)]
# fmt: off
per_dataset_kwargs, weights = get_oxe_dataset_kwargs_and_weights(
self.data_root_dir,
mixture_spec,
load_camera_views=("primary",),
load_depth=False,
load_proprio=False,
load_language=True,
action_proprio_normalization_type=NormalizationType.BOUNDS_Q99,
)
rlds_config = dict(
traj_transform_kwargs=dict(
window_size=1, # If we wanted to feed / predict more than one step
future_action_window_size=future_action_window_size, # For action chunking
skip_unlabeled=True, # Skip trajectories without language labels
goal_relabeling_strategy="uniform", # Goals are currently unused
),
frame_transform_kwargs=dict(
resize_size=resize_resolution,
num_parallel_calls=16, # For CPU-intensive ops (decoding, resizing, etc.)
),
dataset_kwargs_list=per_dataset_kwargs,
shuffle_buffer_size=shuffle_buffer_size,
sample_weights=weights,
balance_weights=True,
traj_transform_threads=len(mixture_spec),
traj_read_threads=len(mixture_spec),
train=train,
)
# If applicable, enable image augmentations
if image_aug:
rlds_config["frame_transform_kwargs"].update({"image_augment_kwargs" : dict(
random_resized_crop=dict(scale=[0.9, 0.9], ratio=[1.0, 1.0]),
random_brightness=[0.2],
random_contrast=[0.8, 1.2],
random_saturation=[0.8, 1.2],
random_hue=[0.05],
augment_order=[
"random_resized_crop",
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
)}),
# fmt: on
# Initialize RLDS Dataset
self.dataset, self.dataset_length, self.dataset_statistics = self.make_dataset(rlds_config)
def make_dataset(self, rlds_config):
return make_interleaved_dataset(**rlds_config)
def __iter__(self) -> Dict[str, Any]:
for rlds_batch in self.dataset.as_numpy_iterator():
yield self.batch_transform(rlds_batch)
def __len__(self) -> int:
return self.dataset_length
# === Explicitly Unused ===
def __getitem__(self, idx: int) -> None:
raise NotImplementedError("IterableDataset does not implement map-style __getitem__; see __iter__ instead!")
class EpisodicRLDSDataset(RLDSDataset):
"""Returns full episodes as list of steps instead of individual transitions (useful for visualizations)."""
def make_dataset(self, rlds_config):
per_dataset_kwargs = rlds_config["dataset_kwargs_list"]
assert len(per_dataset_kwargs) == 1, "Only support single-dataset `mixes` for episodic datasets."
return make_single_dataset(
per_dataset_kwargs[0],
train=rlds_config["train"],
traj_transform_kwargs=rlds_config["traj_transform_kwargs"],
frame_transform_kwargs=rlds_config["frame_transform_kwargs"],
)
def __iter__(self) -> Dict[str, Any]:
for rlds_batch in self.dataset.as_numpy_iterator():
out = [
self.batch_transform(tree_map(lambda x: x[i], rlds_batch)) # noqa: B023
for i in range(rlds_batch["action"].shape[0])
]
yield out
class DummyDataset(Dataset):
def __init__(
self,
action_tokenizer: ActionTokenizer,
base_tokenizer: PreTrainedTokenizerBase,
image_transform: None, # ImageTransform,
prompt_builder_fn: None # Type[PromptBuilder],
) -> None:
self.action_tokenizer = action_tokenizer
self.base_tokenizer = base_tokenizer
self.image_transform = image_transform
self.prompt_builder_fn = prompt_builder_fn
# Note =>> We expect the dataset to store statistics for action de-normalization. Specifically, we store the
# per-dimension 1st and 99th action quantile. The values below correspond to "no normalization" for simplicity.
self.dataset_statistics = {
"dummy_dataset": {
"action": {"q01": np.zeros((7,), dtype=np.float32), "q99": np.ones((7,), dtype=np.float32)}
}
}
def __len__(self):
# TODO =>> Replace with number of elements in your dataset!
return 10000
def __getitem__(self, idx):
# TODO =>> Load image, action and instruction from disk -- we use dummy values
image = Image.fromarray(np.asarray(np.random.rand(224, 224, 3) * 255.0, dtype=np.uint8))
action = np.asarray(np.random.rand(7), dtype=np.float32)
instruction = "do something spectacular"
# Add instruction to VLA prompt
prompt_builder = self.prompt_builder_fn("openvla")
conversation = [
{"from": "human", "value": f"What action should the robot take to {instruction}?"},
{"from": "gpt", "value": self.action_tokenizer(action)},
]
for turn in conversation:
prompt_builder.add_turn(turn["from"], turn["value"])
# Tokenize (w/ `base_tokenizer`)
input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids
labels = list(input_ids)
# Tensorize =>> Run Image Transform to get `pixel_values` =>> Return
# =>> IMPORTANT :: IF WE'RE USING HF .forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)
pixel_values = self.image_transform(image)
# [CRITICAL] We do not want to take the loss for anything but the predicted action tokens!
labels[: -(len(action) + 1)] = IGNORE_INDEX
return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
from .dataset import make_interleaved_dataset, make_single_dataset
"""
dataset.py
Core interface script for configuring and initializing RLDS datasets.
"""
import copy
import inspect
import json
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
import logging
import torch.distributed as dist
import dlimp as dl
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from data.openx.datasets.rlds import obs_transforms, traj_transforms
from data.openx.datasets.rlds.utils import goal_relabeling, task_augmentation
from data.openx.datasets.rlds.utils.data_utils import (
NormalizationType,
allocate_threads,
get_dataset_statistics,
normalize_action_and_proprio,
pprint_data_mixture,
tree_map,
)
logger = logging.getLogger(__name__)
# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch)
tf.config.set_visible_devices([], "GPU")
def partition_dataset(dataset, num_partitions, partition_index):
# Calculate the size of each partition
# total_samples = len(dataset)
# partition_size = total_samples // num_partitions
# # Get the start and end indices for this partition
# start_index = partition_index * partition_size
# end_index = start_index + partition_size if partition_index != num_partitions - 1 else total_samples
# # Partition using skip() and take()
# partitioned_dataset = dataset.skip(start_index).take(end_index - start_index).cache()
partitioned_dataset = dataset.shard(num_shards=num_partitions, index=partition_index)
return partitioned_dataset
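# Usage sketch (illustrative only; assumes torch.distributed may or may not be initialized): give
# each rank a disjoint shard of the dataset so workers do not read duplicate trajectories.
def _example_partition_for_rank(dataset):
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    return partition_dataset(dataset, num_partitions=world_size, partition_index=rank)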
# ruff: noqa: B006
def make_dataset_from_rlds(
name: str,
data_dir: str,
*,
train: bool,
standardize_fn: Optional[Callable[[dict], dict]] = None,
shuffle: bool = True,
image_obs_keys: Dict[str, Optional[str]] = {},
depth_obs_keys: Dict[str, Optional[str]] = {},
state_obs_keys: List[Optional[str]] = (),
language_key: Optional[str] = None,
action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL,
dataset_statistics: Optional[Union[dict, str]] = None,
absolute_action_mask: Optional[List[bool]] = None,
action_normalization_mask: Optional[List[bool]] = None,
num_parallel_reads: int = tf.data.AUTOTUNE,
num_parallel_calls: int = tf.data.AUTOTUNE,
) -> Tuple[dl.DLataset, dict]:
"""
This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized
format. Yields a dataset of trajectories. Does not include CPU-intensive operations.
If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory
into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a
dictionary containing some number of additional keys, which will be extracted into an even more standardized format
according to the "*_obs_keys" arguments.
The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an
old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called
"workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then
the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and
"image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and
"image_wrist" corresponds to "wrist".
Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will
be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each
None entry.
The dataset will also include a "task" dict. If `language_key` is provided, then the "task" dict will contain the
key "language_instruction", extracted from `traj[language_key]`.
Args:
name (str): The name of the RLDS dataset (usually "name" or "name:version").
data_dir (str): The path to the data directory.
train (bool): Whether to use the training or validation split.
shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one
file usually contains many trajectories)!
standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first
thing applied to each trajectory.
image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the
"observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`.
If a value of `old` is None, inserts a padding image instead (empty string).
depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be
prefixed with "depth_" instead of "image_".
state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the
"observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry.
language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction",
extracted from `traj[language_key]`.
action_proprio_normalization_type (str, optional): The type of normalization to perform on the action,
proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]).
dataset_statistics: (dict|str, optional): dict (or path to JSON file) that contains dataset statistics
for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and
"std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max"
keys. May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for
`make_interleaved_dataset`). If not provided, the statistics will be computed on the fly.
absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be
relative. This is important for when `future_action_window_size > 0`: actions that are taken
from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used)
need to be made "neutral" to indicate that the task has been completed. For relative actions,
"neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action.
This mask, if provided, indicates which action dimensions are absolute.
action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions
should be normalized. For example, you might not want to normalize the gripper action dimension if
it's always exactly 0 or 1. By default, all action dimensions are normalized.
num_parallel_reads (int): number of parallel read workers. Defaults to AUTOTUNE.
num_parallel_calls (int): number of parallel calls for traj_map operations. Defaults to AUTOTUNE.
Returns:
Dataset of trajectories where each step has the following fields:
- observation:
- image_{name1, name2, ...} # RGB image observations
- depth_{name1, name2, ...} # depth image observations
- proprio # 1-dimensional array of proprioceptive observations
- timestep # timestep of each frame
- task:
- language_instruction # language instruction, present if `language_key` is provided
- action # action vector
- dataset_name # name of the dataset
"""
REQUIRED_KEYS = {"observation", "action"}
if language_key is not None:
REQUIRED_KEYS.add(language_key)
def restructure(traj):
# apply a standardization function, if provided
if standardize_fn is not None:
traj = standardize_fn(traj)
if not all(k in traj for k in REQUIRED_KEYS):
raise ValueError(
f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?"
)
# extracts images, depth images and proprio from the "observation" dict
traj_len = tf.shape(traj["action"])[0]
old_obs = traj["observation"]
new_obs = {}
for new, old in image_obs_keys.items():
if old is None:
new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding
else:
new_obs[f"image_{new}"] = old_obs[old]
for new, old in depth_obs_keys.items():
if old is None:
new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding
else:
new_obs[f"depth_{new}"] = old_obs[old]
if state_obs_keys:
new_obs["proprio"] = tf.concat(
[
(
tf.zeros((traj_len, 1), dtype=tf.float32) # padding
if key is None
else tf.cast(old_obs[key], tf.float32)
)
for key in state_obs_keys
],
axis=1,
)
# add timestep info
new_obs["timestep"] = tf.range(traj_len)
# extracts `language_key` into the "task" dict
task = {}
if language_key is not None:
if traj[language_key].dtype != tf.string:
raise ValueError(
f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string."
)
task["language_instruction"] = traj.pop(language_key)
traj = {
"observation": new_obs,
"task": task,
"action": tf.cast(traj["action"], tf.float32),
"dataset_name": tf.repeat(name, traj_len),
"_traj_index": traj['_traj_index'],
"_frame_index": traj['_frame_index'],
}
if absolute_action_mask is not None:
if len(absolute_action_mask) != traj["action"].shape[-1]:
raise ValueError(
f"Length of absolute_action_mask ({len(absolute_action_mask)}) "
f"does not match action dimension ({traj['action'].shape[-1]})."
)
traj["absolute_action_mask"] = tf.tile(
tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None],
[traj_len, 1],
)
return traj
builder = tfds.builder(name, data_dir=data_dir)
# load or compute dataset statistics
if isinstance(dataset_statistics, str):
with tf.io.gfile.GFile(dataset_statistics, "r") as f:
dataset_statistics = json.load(f)
elif dataset_statistics is None:
full_dataset = dl.DLataset.from_rlds(
builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads
).traj_map(restructure, num_parallel_calls)
# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
hash_dependencies=(
str(builder.info),
str(state_obs_keys),
inspect.getsource(standardize_fn) if standardize_fn is not None else "",
),
save_dir=builder.data_dir,
)
dataset_statistics = tree_map(np.array, dataset_statistics)
# skip normalization for certain action dimensions
if action_normalization_mask is not None:
if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]:
raise ValueError(
f"Length of skip_normalization_mask ({len(action_normalization_mask)}) "
f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})."
)
dataset_statistics["action"]["mask"] = np.array(action_normalization_mask)
# construct the dataset
if "val" not in builder.info.splits:
split = "train[:95%]" if train else "train[95%:]"
else:
split = "train" if train else "val"
dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads)
# # Function to add episode_id to each trajectory
def add_episode_id(index, element):
# Add the index to the trajectory dictionary (you can modify as needed)
element['episode_id'] = index
return element
# # Example usage of adding a global index to a dataset
def add_episode_id_to_dataset(dataset):
# Enumerate the dataset to get a unique global index for each element
dataset_with_index = dataset.enumerate().map(
add_episode_id, num_parallel_calls=tf.data.AUTOTUNE
)
return dataset_with_index
# dataset = add_episode_id_to_dataset(dataset)
dataset = dataset.traj_map(restructure, num_parallel_calls)
dataset = dataset.traj_map(
partial(
normalize_action_and_proprio,
metadata=dataset_statistics,
normalization_type=action_proprio_normalization_type,
),
num_parallel_calls,
)
return dataset, dataset_statistics
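# Hypothetical invocation sketch (the dataset name, data_dir, and old image key below are
# placeholders, not taken from this repo): load only the "primary" RGB view, keep language
# instructions, and normalize actions with per-dimension q01/q99 bounds.
def _example_make_rlds_dataset(data_dir: str) -> Tuple[dl.DLataset, dict]:
    return make_dataset_from_rlds(
        name="bridge_orig",                     # placeholder TFDS dataset name
        data_dir=data_dir,
        train=True,
        image_obs_keys={"primary": "image_0"},  # placeholder {new: old} camera mapping
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.BOUNDS_Q99,
    )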
def apply_trajectory_transforms(
dataset: dl.DLataset,
*,
train: bool,
goal_relabeling_strategy: Optional[str] = None,
goal_relabeling_kwargs: dict = {},
window_size: int = 1,
future_action_window_size: int = 0,
subsample_length: Optional[int] = None,
skip_unlabeled: bool = False,
max_action: Optional[float] = None,
max_proprio: Optional[float] = None,
task_augment_strategy: Optional[str] = None,
task_augment_kwargs: dict = {},
num_parallel_calls: int = tf.data.AUTOTUNE,
) -> dl.DLataset:
"""
Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling"
(e.g., filtering, chunking, adding goals, dropping keys).
Transforms in this function should have the following properties:
- They require access to an entire trajectory (i.e., they cannot be applied frame-wise).
- They are generally not CPU-intensive, mostly involving moving and copying data.
- They do not require decoded images.
Args:
dataset (dl.DLataset): The dataset to transform.
train (bool): Whether the dataset is for training (affects subsampling).
goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for
no goal relabeling. See `goal_relabeling.py`.
goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function.
window_size (int, optional): The length of the snippets that trajectories are chunked into.
future_action_window_size (int, optional): The number of future actions beyond window_size to include
in the chunked actions.
subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to
this length (after goal relabeling and chunking).
skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels.
max_action: (float, optional): If provided, trajectories in which *any* action dimension
of *any* transition has an absolute value larger than this will be skipped.
max_proprio: (float, optional): If provided, trajectories in which *any* proprio dimension
of *any* transition has an absolute value larger than this will be skipped.
task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task
augmentation. See `task_augmentation.py`.
task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation
function.
num_parallel_calls (int, optional): number of parallel calls for map operations. Defaults to AUTOTUNE.
"""
if skip_unlabeled:
if "language_instruction" not in dataset.element_spec["task"]:
raise ValueError("skip_unlabeled=True but dataset does not have language labels.")
dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != ""))
if max_action is not None:
dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action))
if max_proprio is not None and "proprio" in dataset.element_spec["observation"]:
dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio))
# marks which entries of the observation and task dicts are padding
dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls)
# updates the "task" dict
if goal_relabeling_strategy is not None:
dataset = dataset.traj_map(
partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs),
num_parallel_calls,
)
# must run task augmentation before chunking, in case it changes goal timesteps
if train and task_augment_strategy is not None:
# perform task augmentation (e.g., dropping keys)
dataset = dataset.traj_map(
partial(
getattr(task_augmentation, task_augment_strategy),
**task_augment_kwargs,
),
num_parallel_calls,
)
# chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
# `window_size + future_action_window_size`, respectively
dataset = dataset.traj_map(
partial(
traj_transforms.chunk_act_obs,
window_size=window_size,
future_action_window_size=future_action_window_size,
),
num_parallel_calls,
)
if train and subsample_length is not None:
dataset = dataset.traj_map(
partial(traj_transforms.subsample, subsample_length=subsample_length),
num_parallel_calls,
)
return dataset
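# --- Illustrative usage sketch (not part of the pipeline) ---
# Shows how a trajectory dataset could be chunked and filtered; all keyword values below
# are hypothetical placeholders rather than settings used by this repository.
def _example_trajectory_transforms(trajectory_dataset: dl.DLataset) -> dl.DLataset:
    return apply_trajectory_transforms(
        trajectory_dataset,
        train=True,
        window_size=2,                  # each chunk carries 2 consecutive observations
        future_action_window_size=7,    # plus 7 future actions beyond the window
        skip_unlabeled=True,            # drop trajectories without language labels
        max_action=5.0,                 # drop trajectories with |action| > 5.0 anywhere
        subsample_length=100,           # subsample long trajectories during training
    )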
def apply_per_dataset_frame_transforms(
dataset: dl.DLataset,
chunk_filter_fn: Optional[Callable] = None,
):
"""
    Optionally applies *per-dataset* transforms that happen at the frame level.
    Args:
        dataset (dl.DLataset): The dataset to transform.
        chunk_filter_fn (callable, optional): Filter function for chunks.
"""
if chunk_filter_fn:
dataset = dataset.filter(chunk_filter_fn)
return dataset
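# --- Illustrative usage sketch (not part of the pipeline) ---
# A hypothetical per-dataset chunk filter that keeps only frames whose language
# instruction is non-empty; the predicate is a placeholder, not a repository default.
def _example_per_dataset_frame_transforms(frame_dataset: dl.DLataset) -> dl.DLataset:
    return apply_per_dataset_frame_transforms(
        frame_dataset,
        chunk_filter_fn=lambda frame: tf.reduce_any(frame["task"]["language_instruction"] != ""),
    )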
def apply_frame_transforms(
dataset: dl.DLataset,
*,
train: bool,
image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {},
resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
num_parallel_calls: int = tf.data.AUTOTUNE,
) -> dl.DLataset:
"""
    Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive (e.g.,
    decoding or resizing images).
Args:
train (bool): Whether the dataset is for training (affects image augmentation).
dataset (dl.DLataset): The dataset to transform.
image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation
function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of
dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys`
in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict
to skip augmentation for all images).
resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to
this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names
determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing
keys (so pass an empty dict to skip resizing for all images).
depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth
images.
        num_parallel_calls (int): number of parallel calls for frame_map operations. Defaults to AUTOTUNE.
"""
# Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies
# it to the chunked "observation" dict as well as the non-chunked "task" dict
def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict:
frame["task"] = fn(frame["task"])
frame["observation"] = dl.vmap(fn)(frame["observation"])
        if "observation_future" in frame:
            frame["observation_future"] = dl.vmap(fn)(frame["observation_future"])
return frame
# Decode + resize images (and depth images)
dataset = dataset.frame_map(
partial(
apply_obs_transform,
partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size),
),
num_parallel_calls,
)
if train:
# Augment all images with the same seed, skipping padding images
def aug(frame: dict):
seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)
aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs)
return apply_obs_transform(aug_fn, frame)
dataset = dataset.frame_map(aug, num_parallel_calls)
return dataset
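# --- Illustrative usage sketch (not part of the pipeline) ---
# Per-camera resize sizes plus a single shared augmentation dict (recognized by its
# "augment_order" key, see `obs_transforms.augment`). The augmentation values are
# hypothetical placeholders; `dlimp.transforms.augment_image` documents the full set.
def _example_frame_transforms(frame_dataset: dl.DLataset) -> dl.DLataset:
    return apply_frame_transforms(
        frame_dataset,
        train=True,
        resize_size={"primary": (224, 224), "wrist": (224, 224)},   # keyed by the {name} in image_{name}
        image_augment_kwargs={
            "random_brightness": [0.2],
            "random_contrast": [0.8, 1.2],
            "augment_order": ["random_brightness", "random_contrast"],
        },
    )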
def make_single_dataset(
dataset_kwargs: dict,
*,
train: bool,
traj_transform_kwargs: dict = {},
frame_transform_kwargs: dict = {},
) -> Tuple[dl.DLataset, int, dict]:
    """Creates a single dataset from kwargs. Returns the trajectory dataset, its trajectory count, and its statistics.
Args:
dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific.
train: whether this is a training or validation dataset.
        traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`.
        frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
"""
dataset, dataset_statistics = make_dataset_from_rlds(
**dataset_kwargs,
train=train,
)
dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
# this seems to reduce memory usage without affecting speed
dataset = dataset.with_ram_budget(1)
# save for later
return dataset, dataset_statistics["num_trajectories"], dataset_statistics
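# --- Illustrative usage sketch (not part of the pipeline) ---
# End-to-end construction of one trajectory-level dataset. The dataset name, data_dir,
# and observation key mapping are hypothetical placeholders.
def _example_single_dataset():
    dataset, num_trajectories, statistics = make_single_dataset(
        dataset_kwargs={
            "name": "example_rlds_dataset",          # hypothetical RLDS builder name
            "data_dir": "/path/to/rlds",             # hypothetical storage location
            "image_obs_keys": {"primary": "image_0"},
            "language_key": "language_instruction",
        },
        train=True,
        traj_transform_kwargs={"window_size": 1, "future_action_window_size": 0},
        frame_transform_kwargs={"resize_size": (224, 224)},
    )
    return dataset, num_trajectories, statistics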
# === Core Initializer ===
def make_interleaved_dataset(
dataset_kwargs_list: List[Dict],
sample_weights: Optional[List[float]] = None,
*,
train: bool,
shuffle_buffer_size: int,
traj_transform_kwargs: Optional[Dict] = None,
frame_transform_kwargs: Optional[Dict] = None,
batch_size: Optional[int] = None,
balance_weights: bool = False,
traj_transform_threads: Optional[int] = None,
traj_read_threads: Optional[int] = None,
) -> Tuple[dl.DLataset, int, dict]:
"""
Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames.
Args:
dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`.
"num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and
`traj_read_threads`, respectively.
sample_weights: sampling weights for each dataset in list. If None, defaults to uniform.
train: whether this is a training or validation dataset.
shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames).
traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is
overridden using `traj_transform_threads`.
frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
batch_size: batch size, if not provided output is not batched.
balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset.
This makes it so that, if all the sample weights are equal, one full iteration through the interleaved
dataset will correspond to one full iteration through each individual dataset (only in expectation,
since in practice the sampling is random).
traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across
datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
traj_read_threads: total number of parallel read workers for trajectory transforms, distributed across
datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
"""
# Default to uniform sampling (if `sample_weights` is not specified)
if not sample_weights:
sample_weights = [1.0] * len(dataset_kwargs_list)
if len(sample_weights) != len(dataset_kwargs_list):
raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.")
# Check valid `traj_transform_kwargs` and `frame_transform_kwargs`
if (traj_transform_kwargs is None) or (frame_transform_kwargs is None):
raise ValueError("Missing `traj_transform_kwargs` and `frame_transform_kwargs`!")
# Get Dataset Sizes
dataset_sizes, all_dataset_statistics = [], {}
for dataset_kwargs in dataset_kwargs_list:
data_kwargs = copy.deepcopy(dataset_kwargs)
if "dataset_frame_transform_kwargs" in data_kwargs:
data_kwargs.pop("dataset_frame_transform_kwargs")
_, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train)
dataset_sizes.append(dataset_statistics["num_transitions"])
all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics
# Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0)
primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0])
# Balance and Normalize Weights
if balance_weights:
sample_weights = np.array(sample_weights) * np.array(dataset_sizes)
sample_weights = np.array(sample_weights) / np.sum(sample_weights)
pprint_data_mixture(dataset_kwargs_list, sample_weights)
# Effective Dataset Length = Number of samples until each dataset has completed at least one epoch
# =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0)
dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max())
# Allocate Threads based on Weights
threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights)
reads_per_dataset = allocate_threads(traj_read_threads, sample_weights)
logger.info("Threads per Dataset: %s", threads_per_dataset)
logger.info("Reads per Dataset: %s", reads_per_dataset)
rank = dist.get_rank() # Rank of the current process
world_size = dist.get_world_size() # Total number of processes
# Construct Datasets
logger.info("Constructing datasets...")
datasets = []
for dataset_kwargs, threads, reads in zip(
dataset_kwargs_list,
threads_per_dataset,
reads_per_dataset,
):
dataset_frame_transform_kwargs = (
dataset_kwargs.pop("dataset_frame_transform_kwargs")
if "dataset_frame_transform_kwargs" in dataset_kwargs
else {}
)
dataset, _ = make_dataset_from_rlds(
**dataset_kwargs,
train=train,
num_parallel_calls=threads,
num_parallel_reads=reads,
dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]],
)
# split dataset per process
dataset = partition_dataset(dataset, num_partitions=world_size, partition_index=rank)
print(f"Rank {rank} has {len(dataset)} samples")
dataset = apply_trajectory_transforms(
dataset.repeat(),
**traj_transform_kwargs,
num_parallel_calls=threads,
train=train,
).flatten(num_parallel_calls=threads)
dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs)
datasets.append(dataset)
# Interleave at the Frame Level
dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights)
# Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase!
if not train:
dataset = dataset.take(shuffle_buffer_size).cache()
# Shuffle the Dataset
# =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak!
dataset = dataset.shuffle(shuffle_buffer_size)
# Apply Frame Transforms
logger.info("Applying frame transforms on dataset...")
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
# [Contract] When training VLA Policies, we let the Collator handle Batching!
if batch_size is not None:
dataset = dataset.batch(batch_size)
# Note =>> Seems to reduce memory usage without affecting speed?
dataset = dataset.with_ram_budget(1)
# Save for Later
dataset.sample_weights = sample_weights
return dataset, dataset_len, all_dataset_statistics
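# --- Illustrative usage sketch (not part of the pipeline) ---
# Interleaving two RLDS datasets with 2:1 sampling weights; names, paths, buffer size, and
# thread counts are hypothetical placeholders. With `balance_weights=True`, the weights would
# additionally be scaled by dataset size before normalization. This variant also assumes the
# surrounding distributed setup (torch.distributed) is already initialized, since it partitions
# each dataset by process rank.
def _example_interleaved_dataset():
    dataset, dataset_len, statistics = make_interleaved_dataset(
        dataset_kwargs_list=[
            {"name": "example_dataset_a", "data_dir": "/path/to/rlds"},
            {"name": "example_dataset_b", "data_dir": "/path/to/rlds"},
        ],
        sample_weights=[2.0, 1.0],
        train=True,
        shuffle_buffer_size=100_000,
        traj_transform_kwargs={"window_size": 1},
        frame_transform_kwargs={"resize_size": (224, 224)},
        batch_size=None,                 # batching is deferred to the collator
        traj_transform_threads=16,
        traj_read_threads=16,
    )
    return dataset, dataset_len, statistics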
"""
dataset.py
Core interface script for configuring and initializing RLDS datasets.
"""
import copy
import inspect
import json
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
import dlimp as dl
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from prismatic.logging import initialize_logging
from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms
from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation
from prismatic.vla.datasets.rlds.utils.data_utils import (
NormalizationType,
allocate_threads,
get_dataset_statistics,
normalize_action_and_proprio,
pprint_data_mixture,
tree_map,
)
# Initialize logging =>> Wraps `logging.Logger`
logging = initialize_logging(__name__)
# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch)
tf.config.set_visible_devices([], "GPU")
# ruff: noqa: B006
def make_dataset_from_rlds(
name: str,
data_dir: str,
*,
train: bool,
standardize_fn: Optional[Callable[[dict], dict]] = None,
shuffle: bool = True,
image_obs_keys: Dict[str, Optional[str]] = {},
depth_obs_keys: Dict[str, Optional[str]] = {},
state_obs_keys: List[Optional[str]] = (),
language_key: Optional[str] = None,
action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL,
dataset_statistics: Optional[Union[dict, str]] = None,
absolute_action_mask: Optional[List[bool]] = None,
action_normalization_mask: Optional[List[bool]] = None,
num_parallel_reads: int = tf.data.AUTOTUNE,
num_parallel_calls: int = tf.data.AUTOTUNE,
future_action_window_size: int = 0, ## not used in this function, but used in apply_trajectory_transforms
) -> Tuple[dl.DLataset, dict]:
"""
This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized
format. Yields a dataset of trajectories. Does not include CPU-intensive operations.
If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory
into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a
dictionary containing some number of additional keys, which will be extracted into an even more standardized format
according to the "*_obs_keys" arguments.
The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an
old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called
"workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then
the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and
"image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and
"image_wrist" corresponds to "wrist".
Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will
be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each
None entry.
The dataset will also include a "task" dict. If `language_key` is provided, then the "task" dict will contain the
key "language_instruction", extracted from `traj[language_key]`.
Args:
name (str): The name of the RLDS dataset (usually "name" or "name:version").
data_dir (str): The path to the data directory.
train (bool): Whether to use the training or validation split.
shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one
file usually contains many trajectories)!
standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first
thing applied to each trajectory.
image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the
"observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`.
If a value of `old` is None, inserts a padding image instead (empty string).
depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be
prefixed with "depth_" instead of "image_".
state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the
"observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry.
language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction",
extracted from `traj[language_key]`.
action_proprio_normalization_type (str, optional): The type of normalization to perform on the action,
proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]).
        dataset_statistics (dict|str, optional): dict (or path to JSON file) that contains dataset statistics
for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and
"std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max"
keys. May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for
`make_interleaved_dataset`). If not provided, the statistics will be computed on the fly.
absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be
relative. This is important for when `future_action_window_size > 0`: actions that are taken
from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used)
need to be made "neutral" to indicate that the task has been completed. For relative actions,
"neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action.
This mask, if provided, indicates which action dimensions are absolute.
action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions
should be normalized. For example, you might not want to normalize the gripper action dimension if
it's always exactly 0 or 1. By default, all action dimensions are normalized.
        num_parallel_reads (int): number of parallel read workers. Defaults to AUTOTUNE.
        num_parallel_calls (int): number of parallel calls for traj_map operations. Defaults to AUTOTUNE.
Returns:
Dataset of trajectories where each step has the following fields:
- observation:
- image_{name1, name2, ...} # RGB image observations
- depth_{name1, name2, ...} # depth image observations
- proprio # 1-dimensional array of proprioceptive observations
- timestep # timestep of each frame
- task:
- language_instruction # language instruction, present if `language_key` is provided
- action # action vector
- dataset_name # name of the dataset
"""
print("dataset name: ", name)
print("window size: ", future_action_window_size)
REQUIRED_KEYS = {"observation", "action"}
if language_key is not None:
REQUIRED_KEYS.add(language_key)
def restructure(traj):
# apply a standardization function, if provided
if standardize_fn is not None:
traj = standardize_fn(traj)
if not all(k in traj for k in REQUIRED_KEYS):
raise ValueError(
f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?"
)
# extracts images, depth images and proprio from the "observation" dict
traj_len = tf.shape(traj["action"])[0]
old_obs = traj["observation"]
new_obs = {}
for new, old in image_obs_keys.items():
if old is None:
new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding
else:
new_obs[f"image_{new}"] = old_obs[old]
for new, old in depth_obs_keys.items():
if old is None:
new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding
else:
new_obs[f"depth_{new}"] = old_obs[old]
if state_obs_keys:
new_obs["proprio"] = tf.concat(
[
(
tf.zeros((traj_len, 1), dtype=tf.float32) # padding
if key is None
else tf.cast(old_obs[key], tf.float32)
)
for key in state_obs_keys
],
axis=1,
)
# add timestep info
new_obs["timestep"] = tf.range(traj_len)
# extracts `language_key` into the "task" dict
task = {}
if language_key is not None:
if traj[language_key].dtype != tf.string:
raise ValueError(
f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string."
)
task["language_instruction"] = traj.pop(language_key)
traj = {
"observation": new_obs,
"task": task,
"action": tf.cast(traj["action"], tf.float32),
"dataset_name": tf.repeat(name, traj_len),
}
if absolute_action_mask is not None:
if len(absolute_action_mask) != traj["action"].shape[-1]:
raise ValueError(
f"Length of absolute_action_mask ({len(absolute_action_mask)}) "
f"does not match action dimension ({traj['action'].shape[-1]})."
)
traj["absolute_action_mask"] = tf.tile(
tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None],
[traj_len, 1],
)
return traj
builder = tfds.builder(name, data_dir=data_dir)
# load or compute dataset statistics
if isinstance(dataset_statistics, str):
with tf.io.gfile.GFile(dataset_statistics, "r") as f:
dataset_statistics = json.load(f)
elif dataset_statistics is None:
full_dataset = dl.DLataset.from_rlds(
builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads
).traj_map(restructure, num_parallel_calls)
# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
hash_dependencies=(
str(builder.info),
str(state_obs_keys),
inspect.getsource(standardize_fn) if standardize_fn is not None else "",
),
save_dir=builder.data_dir,
)
dataset_statistics = tree_map(np.array, dataset_statistics)
# skip normalization for certain action dimensions
if action_normalization_mask is not None:
if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]:
raise ValueError(
f"Length of skip_normalization_mask ({len(action_normalization_mask)}) "
f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})."
)
dataset_statistics["action"]["mask"] = np.array(action_normalization_mask)
# construct the dataset
if "val" not in builder.info.splits:
split = "train[:95%]" if train else "train[95%:]"
else:
split = "train" if train else "val"
dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads)
dataset = dataset.traj_map(restructure, num_parallel_calls)
dataset = dataset.traj_map(
partial(
normalize_action_and_proprio,
metadata=dataset_statistics,
normalization_type=action_proprio_normalization_type,
),
num_parallel_calls,
)
return dataset, dataset_statistics
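# --- Illustrative usage sketch (not part of the pipeline) ---
# Mirrors the key remapping described in the docstring above: "workspace" becomes
# "image_primary", "image_secondary" is inserted as padding, and "wrist" becomes
# "image_wrist". The dataset name and data_dir are hypothetical placeholders.
def _example_make_dataset_from_rlds():
    dataset, statistics = make_dataset_from_rlds(
        name="example_rlds_dataset",
        data_dir="/path/to/rlds",
        train=True,
        image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"},
        state_obs_keys=["joint_states", None, "gripper_state"],   # None inserts one zero pad dimension
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
    )
    return dataset, statistics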
def apply_trajectory_transforms(
dataset: dl.DLataset,
*,
train: bool,
goal_relabeling_strategy: Optional[str] = None,
goal_relabeling_kwargs: dict = {},
window_size: int = 1,
future_action_window_size: int = 0,
subsample_length: Optional[int] = None,
skip_unlabeled: bool = False,
max_action: Optional[float] = None,
max_proprio: Optional[float] = None,
task_augment_strategy: Optional[str] = None,
task_augment_kwargs: dict = {},
num_parallel_calls: int = tf.data.AUTOTUNE,
latent: bool = False,
) -> dl.DLataset:
"""
Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling"
(e.g., filtering, chunking, adding goals, dropping keys).
Transforms in this function should have the following properties:
- They require access to an entire trajectory (i.e., they cannot be applied frame-wise).
- They are generally not CPU-intensive, mostly involving moving and copying data.
- They do not require decoded images.
Args:
dataset (dl.DLataset): The dataset to transform.
train (bool): Whether the dataset is for training (affects subsampling).
goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for
no goal relabeling. See `goal_relabeling.py`.
goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function.
window_size (int, optional): The length of the snippets that trajectories are chunked into.
future_action_window_size (int, optional): The number of future actions beyond window_size to include
in the chunked actions.
subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to
this length (after goal relabeling and chunking).
skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels.
        max_action (float, optional): If provided, trajectories in which *any* action dimension
            of *any* transition has an absolute value larger than this will be skipped.
        max_proprio (float, optional): If provided, trajectories in which *any* proprio dimension
            of *any* transition has an absolute value larger than this will be skipped.
task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task
augmentation. See `task_augmentation.py`.
task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation
function.
        num_parallel_calls (int, optional): number of parallel calls for map operations. Defaults to AUTOTUNE.
"""
if skip_unlabeled:
if "language_instruction" not in dataset.element_spec["task"]:
raise ValueError("skip_unlabeled=True but dataset does not have language labels.")
dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != ""))
if max_action is not None:
dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action))
if max_proprio is not None and "proprio" in dataset.element_spec["observation"]:
dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio))
    # marks which entries of the observation and task dicts are padding
dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls)
# updates the "task" dict
if goal_relabeling_strategy is not None:
dataset = dataset.traj_map(
partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs),
num_parallel_calls,
)
# must run task augmentation before chunking, in case it changes goal timesteps
if train and task_augment_strategy is not None:
# perform task augmentation (e.g., dropping keys)
dataset = dataset.traj_map(
partial(
getattr(task_augmentation, task_augment_strategy),
**task_augment_kwargs,
),
num_parallel_calls,
)
if latent:
dataset = dataset.traj_map(
partial(
traj_transforms.chunk_act_obs_latent,
window_size=window_size,
future_action_window_size=future_action_window_size,
),
num_parallel_calls,
)
else:
# chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
# `window_size + future_action_window_size`, respectively
dataset = dataset.traj_map(
partial(
traj_transforms.chunk_act_obs,
window_size=window_size,
future_action_window_size=future_action_window_size,
),
num_parallel_calls,
)
if train and subsample_length is not None:
dataset = dataset.traj_map(
partial(traj_transforms.subsample, subsample_length=subsample_length),
num_parallel_calls,
)
return dataset
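# --- Illustrative usage sketch (not part of the pipeline) ---
# Identical in spirit to the non-latent call, except `latent=True` routes chunking through
# `traj_transforms.chunk_act_obs_latent`. Window sizes are hypothetical placeholders.
def _example_latent_trajectory_transforms(trajectory_dataset: dl.DLataset) -> dl.DLataset:
    return apply_trajectory_transforms(
        trajectory_dataset,
        train=True,
        window_size=2,
        future_action_window_size=7,
        latent=True,
    )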
def apply_per_dataset_frame_transforms(
dataset: dl.DLataset,
chunk_filter_fn: Optional[Callable] = None,
):
"""
    Optionally applies *per-dataset* transforms that happen at the frame level.
    Args:
        dataset (dl.DLataset): The dataset to transform.
        chunk_filter_fn (callable, optional): Filter function for chunks.
"""
if chunk_filter_fn:
dataset = dataset.filter(chunk_filter_fn)
return dataset
def apply_frame_transforms(
dataset: dl.DLataset,
*,
train: bool,
image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {},
resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
num_parallel_calls: int = tf.data.AUTOTUNE,
latent: bool = False,
) -> dl.DLataset:
"""
    Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive (e.g.,
    decoding or resizing images).
Args:
train (bool): Whether the dataset is for training (affects image augmentation).
dataset (dl.DLataset): The dataset to transform.
image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation
function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of
dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys`
in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict
to skip augmentation for all images).
resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to
this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names
determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing
keys (so pass an empty dict to skip resizing for all images).
depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth
images.
        num_parallel_calls (int): number of parallel calls for frame_map operations. Defaults to AUTOTUNE.
"""
# Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies
# it to the chunked "observation" dict as well as the non-chunked "task" dict
def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict:
frame["task"] = fn(frame["task"])
frame["observation"] = dl.vmap(fn)(frame["observation"])
return frame
def apply_obs_transform_latent(fn: Callable[[Dict], Dict], frame: Dict) -> Dict:
frame["observation_latent_curr"] = dl.vmap(fn)(frame["observation_latent_curr"])
frame["observation_latent_next"] = dl.vmap(fn)(frame["observation_latent_next"])
return frame
# Decode + resize images (and depth images)
dataset = dataset.frame_map(
partial(
apply_obs_transform,
partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size),
),
num_parallel_calls,
)
if latent:
dataset = dataset.frame_map(
partial(
apply_obs_transform_latent,
# partial(obs_transforms.latent_labeling, resize_size=resize_size, depth_resize_size=depth_resize_size, codebook_size=codebook_size, codebook_seq = codebook_seq, ckpt_path = ckpt_path),
partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size),
),
num_parallel_calls,
)
if train:
# Augment all images with the same seed, skipping padding images
def aug(frame: dict):
seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)
aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs)
return apply_obs_transform(aug_fn, frame)
dataset = dataset.frame_map(aug, num_parallel_calls)
return dataset
def make_single_dataset(
dataset_kwargs: dict,
*,
train: bool,
traj_transform_kwargs: dict = {},
frame_transform_kwargs: dict = {},
) -> Tuple[dl.DLataset, int, dict]:
    """Creates a single dataset from kwargs. Returns the trajectory dataset, its trajectory count, and its statistics.
Args:
dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific.
train: whether this is a training or validation dataset.
        traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`.
        frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
"""
dataset, dataset_statistics = make_dataset_from_rlds(
**dataset_kwargs,
train=train,
)
dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
# this seems to reduce memory usage without affecting speed
dataset = dataset.with_ram_budget(1)
# save for later
return dataset, dataset_statistics["num_trajectories"], dataset_statistics
# === Core Initializer ===
def make_interleaved_dataset(
dataset_kwargs_list: List[Dict],
sample_weights: Optional[List[float]] = None,
*,
train: bool,
shuffle_buffer_size: int,
traj_transform_kwargs: Optional[Dict] = None,
frame_transform_kwargs: Optional[Dict] = None,
batch_size: Optional[int] = None,
balance_weights: bool = False,
traj_transform_threads: Optional[int] = None,
traj_read_threads: Optional[int] = None,
) -> Tuple[dl.DLataset, int, dict]:
"""
Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames.
Args:
dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`.
"num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and
`traj_read_threads`, respectively.
sample_weights: sampling weights for each dataset in list. If None, defaults to uniform.
train: whether this is a training or validation dataset.
shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames).
traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is
overridden using `traj_transform_threads`.
frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
batch_size: batch size, if not provided output is not batched.
balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset.
This makes it so that, if all the sample weights are equal, one full iteration through the interleaved
dataset will correspond to one full iteration through each individual dataset (only in expectation,
since in practice the sampling is random).
traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across
datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
traj_read_threads: total number of parallel read workers for trajectory transforms, distributed across
datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
"""
# Default to uniform sampling (if `sample_weights` is not specified)
if not sample_weights:
sample_weights = [1.0] * len(dataset_kwargs_list)
if len(sample_weights) != len(dataset_kwargs_list):
raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.")
# Check valid `traj_transform_kwargs` and `frame_transform_kwargs`
if (traj_transform_kwargs is None) or (frame_transform_kwargs is None):
raise ValueError("Missing `traj_transform_kwargs` and `frame_transform_kwargs`!")
# Get Dataset Sizes
dataset_sizes, all_dataset_statistics = [], {}
for dataset_kwargs in dataset_kwargs_list:
data_kwargs = copy.deepcopy(dataset_kwargs)
if "dataset_frame_transform_kwargs" in data_kwargs:
data_kwargs.pop("dataset_frame_transform_kwargs")
_, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train)
dataset_sizes.append(dataset_statistics["num_transitions"])
all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics
# Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0)
primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0])
# Balance and Normalize Weights
if balance_weights:
sample_weights = np.array(sample_weights) * np.array(dataset_sizes)
sample_weights = np.array(sample_weights) / np.sum(sample_weights)
pprint_data_mixture(dataset_kwargs_list, sample_weights)
# Effective Dataset Length = Number of samples until each dataset has completed at least one epoch
# =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0)
dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max())
# Allocate Threads based on Weights
threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights)
reads_per_dataset = allocate_threads(traj_read_threads, sample_weights)
logging.info("Threads per Dataset: %s", threads_per_dataset)
logging.info("Reads per Dataset: %s", reads_per_dataset)
# Construct Datasets
logging.info("Constructing datasets...")
datasets = []
for dataset_kwargs, threads, reads in zip(
dataset_kwargs_list,
threads_per_dataset,
reads_per_dataset,
):
dataset_frame_transform_kwargs = (
dataset_kwargs.pop("dataset_frame_transform_kwargs")
if "dataset_frame_transform_kwargs" in dataset_kwargs
else {}
)
dataset, _ = make_dataset_from_rlds(
**dataset_kwargs,
train=train,
num_parallel_calls=threads,
num_parallel_reads=reads,
dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]],
)
dataset = apply_trajectory_transforms(
dataset.repeat(),
**traj_transform_kwargs,
future_action_window_size=dataset_kwargs["future_action_window_size"],
num_parallel_calls=threads,
train=train,
).flatten(num_parallel_calls=threads)
dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs)
datasets.append(dataset)
# Interleave at the Frame Level
dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights)
# Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase!
if not train:
dataset = dataset.take(shuffle_buffer_size).cache()
# Shuffle the Dataset
# =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak!
dataset = dataset.shuffle(shuffle_buffer_size)
# Apply Frame Transforms
logging.info("Applying frame transforms on dataset...")
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
# [Contract] When training VLA Policies, we let the Collator handle Batching!
if batch_size is not None:
dataset = dataset.batch(batch_size)
# Note =>> Seems to reduce memory usage without affecting speed?
dataset = dataset.with_ram_budget(1)
# Save for Later
dataset.sample_weights = sample_weights
return dataset, dataset_len, all_dataset_statistics
"""
obs_transforms.py
Contains observation-level transforms used in the orca data pipeline.
These transforms operate on the "observation" dictionary, and are applied at a per-frame level.
"""
from typing import Dict, Tuple, Union
import dlimp as dl
import tensorflow as tf
from absl import logging
# ruff: noqa: B023
def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict:
"""Augments images, skipping padding images."""
image_names = {key[6:] for key in obs if key.startswith("image_")}
# "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed
# in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image
# name to augmentation dict)
if "augment_order" in augment_kwargs:
augment_kwargs = {name: augment_kwargs for name in image_names}
for i, name in enumerate(image_names):
if name not in augment_kwargs:
continue
kwargs = augment_kwargs[name]
logging.debug(f"Augmenting image_{name} with kwargs {kwargs}")
obs[f"image_{name}"] = tf.cond(
obs["pad_mask_dict"][f"image_{name}"],
lambda: dl.transforms.augment_image(
obs[f"image_{name}"],
**kwargs,
seed=seed + i, # augment each image differently
),
lambda: obs[f"image_{name}"], # skip padding images
)
return obs
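# --- Illustrative sketch of the two accepted `augment_kwargs` layouts (values are hypothetical) ---
# 1) A single dict containing "augment_order" is broadcast to every image key:
_SHARED_AUGMENT_KWARGS = {
    "random_brightness": [0.2],
    "augment_order": ["random_brightness"],
}
# 2) A mapping from image name (the {name} in "image_{name}") to its own augmentation dict;
#    names missing from the mapping are skipped entirely:
_PER_IMAGE_AUGMENT_KWARGS = {
    "primary": {"random_brightness": [0.2], "augment_order": ["random_brightness"]},
    # "wrist" deliberately omitted => image_wrist is left unaugmented
}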
def decode_and_resize(
obs: Dict,
resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]],
depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]],
) -> Dict:
"""Decodes images and depth images, and then optionally resizes them."""
image_names = {key[6:] for key in obs if key.startswith("image_")}
depth_names = {key[6:] for key in obs if key.startswith("depth_")}
if isinstance(resize_size, tuple):
resize_size = {name: resize_size for name in image_names}
if isinstance(depth_resize_size, tuple):
depth_resize_size = {name: depth_resize_size for name in depth_names}
for name in image_names:
if name not in resize_size:
logging.warning(
f"No resize_size was provided for image_{name}. This will result in 1x1 "
"padding images, which may cause errors if you mix padding and non-padding images."
)
image = obs[f"image_{name}"]
if image.dtype == tf.string:
if tf.strings.length(image) == 0:
# this is a padding image
image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8)
else:
image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8)
elif image.dtype != tf.uint8:
raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}")
if name in resize_size:
image = dl.transforms.resize_image(image, size=resize_size[name])
obs[f"image_{name}"] = image
for name in depth_names:
if name not in depth_resize_size:
logging.warning(
f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 "
"padding depth images, which may cause errors if you mix padding and non-padding images."
)
depth = obs[f"depth_{name}"]
if depth.dtype == tf.string:
if tf.strings.length(depth) == 0:
depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32)
else:
depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0]
elif depth.dtype != tf.float32:
raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}")
if name in depth_resize_size:
depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name])
obs[f"depth_{name}"] = depth
return obs
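# --- Illustrative usage sketch (not part of the pipeline) ---
# Decodes one JPEG-encoded primary image next to an empty-string padding image. In the real
# pipeline this runs inside `dataset.frame_map`; here it is applied eagerly to a hand-built
# observation dict, with hypothetical resize sizes.
def _example_decode_and_resize() -> Dict:
    obs = {
        "image_primary": tf.io.encode_jpeg(tf.zeros((64, 64, 3), dtype=tf.uint8)),  # encoded bytes
        "image_secondary": tf.constant("", dtype=tf.string),                        # padding image
    }
    return decode_and_resize(
        obs,
        resize_size={"primary": (224, 224), "secondary": (224, 224)},
        depth_resize_size={},
    )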
from .materialize import get_oxe_dataset_kwargs_and_weights
from .mixtures import OXE_NAMED_MIXTURES