""" materialize.py Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and exports individual functions for clear control flow. """ from pathlib import Path from typing import Tuple, Type, Dict, Sequence from dataclasses import dataclass, field from torch.utils.data import Dataset from transformers import PreTrainedTokenizerBase import torch from .action_tokenizer import ActionTokenizer from .datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset @dataclass class PaddedCollatorForLanguageModeling: model_max_length: int pad_token_id: int default_image_resolution: Tuple[int, int, int] padding_side: str = "right" pixel_values_dtype: torch.dtype = torch.float32 def __post_init__(self) -> None: self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype) def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) pixel_values = [instance["pixel_values"] for instance in instances] pixel_values_future = [instance["pixel_values_future"] for instance in instances] # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!) # => Handle padding via RNN Utils => `pad_sequence` input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) # Truncate (if necessary) input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] # Get `attention_mask` by checking for `pad_token_id` attention_mask = input_ids.ne(self.pad_token_id) # === Handle "unimodal" (language-only) vs. "multimodal" === # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily multimodal_indices = torch.tensor( [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long ) # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None if len(multimodal_indices) == 0: pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))]) elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor): pixel_values = torch.stack( [ pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values for idx in range(len(input_ids)) ] ) elif isinstance(pv_example, dict): pixel_values = { k: torch.stack( [ pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values for idx in range(len(input_ids)) ] ) for k in pv_example } else: raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") return dict( pixel_values=pixel_values, pixel_values_future=pixel_values_future, input_ids=input_ids, attention_mask=attention_mask, labels=labels, multimodal_indices=multimodal_indices, ) @dataclass class PaddedCollatorForActionPrediction: model_max_length: int pad_token_id: int padding_side: str = "right" pixel_values_dtype: torch.dtype = torch.float32 def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) pixel_values = [instance["pixel_values"] for instance in instances] pixel_values_future = [instance["pixel_values_future"] for instance in instances] if "dataset_name" in instances[0]: dataset_names = [instance["dataset_name"] for instance in instances] else: dataset_names = None # For now, we only support Tokenizers with `padding_side = "right"` during training # => Handle padding via RNN Utils => `pad_sequence` assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`" input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) # Truncate (if necessary) input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] # Get `attention_mask` by checking for `pad_token_id` attention_mask = input_ids.ne(self.pad_token_id) # [Contract] For VLA Training =>> No "Unimodal" Data! assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!" # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor] if isinstance(pixel_values[0], torch.Tensor): pixel_values = torch.stack(pixel_values) elif isinstance(pixel_values[0], dict): pixel_values = { k: torch.stack([pixel_values[idx][k] for idx in range(len(input_ids))]) for k in pixel_values[0] } else: raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") if isinstance(pixel_values_future[0], torch.Tensor): pixel_values_future = torch.stack(pixel_values_future) elif isinstance(pixel_values_future[0], dict): pixel_values_future = { k: torch.stack([pixel_values_future[idx][k] for idx in range(len(input_ids))]) for k in pixel_values_future[0] } else: raise ValueError(f"Unsupported `pixel_values_future` type = {type(pixel_values_future)}") output = dict( pixel_values=pixel_values, pixel_values_future=pixel_values_future, input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) if dataset_names is not None: output["dataset_names"] = dataset_names return output @dataclass class PaddedCollatorForEpisodeActionPrediction: model_max_length: int pad_token_id: int padding_side: str = "right" pixel_values_dtype: torch.dtype = torch.float32 def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) pixel_values = [instance["pixel_values"] for instance in instances] if "dataset_name" in instances[0]: dataset_names = [instance["dataset_name"] for instance in instances] else: dataset_names = None # For now, we only support Tokenizers with `padding_side = "right"` during training # => Handle padding via RNN Utils => `pad_sequence` assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`" input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) # Truncate (if necessary) input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] # Get `attention_mask` by checking for `pad_token_id` attention_mask = input_ids.ne(self.pad_token_id) # [Contract] For VLA Training =>> No "Unimodal" Data! assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!" # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor] if isinstance(pixel_values[0], torch.Tensor): pixel_values = torch.stack(pixel_values) elif isinstance(pixel_values[0], dict): pixel_values = { k: torch.stack([pixel_values[idx][k] for idx in range(len(input_ids))]) for k in pixel_values[0] } else: raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") output = dict( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) if dataset_names is not None: output["dataset_names"] = dataset_names return output def get_vla_dataset_and_collator( data_root_dir: Path, data_mix: str, image_transform: None, # ImageTransform, visual_tracker: None, dataset_settings: None, tokenizer: PreTrainedTokenizerBase, prompt_builder_fn: None, # Type[PromptBuilder], default_image_resolution: Tuple[int, int, int], padding_side: str = "right", predict_stop_token: bool = True, shuffle_buffer_size: int = 100_000, train: bool = True, episodic: bool = False, image_aug: bool = False, future_action_window_size: int = 0, local_run: bool = False, ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" action_tokenizer = ActionTokenizer(tokenizer) batch_transform = RLDSBatchTransform( action_tokenizer, tokenizer, image_transform, prompt_builder_fn, visual_tracker, dataset_settings, data_root_dir, predict_stop_token=predict_stop_token, local_run=local_run ) collator = PaddedCollatorForActionPrediction( tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side ) # Build RLDS Iterable Dataset cls = RLDSDataset if not episodic else EpisodicRLDSDataset dataset = cls( data_root_dir, data_mix, batch_transform, resize_resolution=default_image_resolution[1:], shuffle_buffer_size=shuffle_buffer_size, train=train, image_aug=image_aug, future_action_window_size=future_action_window_size, ) return dataset, action_tokenizer, collator