Commit 0063a668 authored by chenzk

v1.0

# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,intellij+all,vim
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,intellij+all,vim
### Intellij+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### Intellij+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.
.idea/*
!.idea/codeStyles
!.idea/runConfigurations
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
# Session
Session.vim
Sessionx.vim
# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,intellij+all,vim
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: check-ast
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: end-of-file-fixer
- id: trailing-whitespace
- id: detect-private-key
- id: debug-statements
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
exclude: __init__.py
args: ["--profile", "black"]
# DLIMP
Dataloading is my passion.
## Installation
Requires Python >= 3.8.
```bash
git clone https://github.com/kvablack/dlimp
cd dlimp
pip install -e .
```
## Usage
Core usage is through the `DLataset` class, defined in `dlimp/dlimp/dataset.py`. It is a thin wrapper around `tf.data.Dataset` designed for working with datasets of trajectories; it has two creation methods, `from_tfrecords` and `from_rlds`. This library additionally provides a suite of *frame-level* and *trajectory-level* transforms designed to be used with `DLataset.frame_map` and `DLataset.traj_map`, respectively.
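A minimal usage sketch (the tfrecord path is illustrative; it assumes trajectories that contain an `observations` key and encoded image leaves whose key contains `"image"`):
```python
import dlimp as dl

dataset = (
    dl.DLataset.from_tfrecords("/path/to/tfrecords", shuffle=True)
    .traj_map(dl.transforms.add_next_obs)    # trajectory-level transform
    .flatten()                                # dataset of trajectories -> dataset of frames
    .frame_map(dl.transforms.decode_images)   # frame-level transform
)

for frame in dataset.iterator():
    ...  # each frame is a dict of numpy arrays
```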
Scripts for converting various datasets to the dlimp TFRecord format (compatible with `DLataset.from_tfrecords`) can be found in `legacy_converters/`. These should be considered deprecated in favor of the RLDS format; RLDS converters can be found in `rlds_converters/` and will continue to be expanded going forward.
from . import transforms
from .dataset import DLataset
from .utils import vmap, parallel_vmap
from typing import Optional
import tensorflow as tf
def random_resized_crop(image, scale, ratio, seed):
assert image.shape.ndims == 3 or image.shape.ndims == 4
if image.shape.ndims == 3:
image = tf.expand_dims(image, axis=0)
batch_size = tf.shape(image)[0]
# taken from https://keras.io/examples/vision/nnclr/#random-resized-crops
log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1]))
height = tf.shape(image)[1]
width = tf.shape(image)[2]
random_scales = tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1])
random_ratios = tf.exp(
tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1])
)
new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1)
new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1)
height_offsets = tf.random.stateless_uniform(
(batch_size,), seed, 0, 1 - new_heights
)
width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths)
bounding_boxes = tf.stack(
[
height_offsets,
width_offsets,
height_offsets + new_heights,
width_offsets + new_widths,
],
axis=1,
)
image = tf.image.crop_and_resize(
image, bounding_boxes, tf.range(batch_size), (height, width)
)
if image.shape[0] == 1:
return image[0]
else:
return image
def random_rot90(image, seed):
k = tf.random.stateless_uniform((), seed, 0, 4, dtype=tf.int32)
return tf.image.rot90(image, k=k)
AUGMENT_OPS = {
"random_resized_crop": random_resized_crop,
"random_brightness": tf.image.stateless_random_brightness,
"random_contrast": tf.image.stateless_random_contrast,
"random_saturation": tf.image.stateless_random_saturation,
"random_hue": tf.image.stateless_random_hue,
"random_flip_left_right": tf.image.stateless_random_flip_left_right,
"random_flip_up_down": tf.image.stateless_random_flip_up_down,
"random_rot90": random_rot90,
}
def augment_image(
image: tf.Tensor, seed: Optional[tf.Tensor] = None, **augment_kwargs
) -> tf.Tensor:
"""Unified image augmentation function for TensorFlow.
This function is primarily configured through `augment_kwargs`. There must be one kwarg called "augment_order",
which is a list of strings specifying the augmentation operations to apply and the order in which to apply them. See
the `AUGMENT_OPS` dictionary above for a list of available operations.
For each entry in "augment_order", there may be a corresponding kwarg with the same name. The value of this kwarg
can be a dictionary of kwargs or a sequence of positional args to pass to the corresponding augmentation operation.
This additional kwarg is required for all operations that take additional arguments other than the image and random
seed. For example, the "random_resized_crop" operation requires a "scale" and "ratio" argument that can be specified
either positionally or by name. "random_flip_left_right", on the other hand, does not take any additional arguments
and so does not require an additional kwarg to configure it.
Here is an example config:
```
augment_kwargs = {
"augment_order": ["random_resized_crop", "random_brightness", "random_contrast", "random_flip_left_right"],
"random_resized_crop": {
"scale": [0.8, 1.0],
"ratio": [3/4, 4/3],
},
"random_brightness": [0.1],
"random_contrast": [0.9, 1.1],
}
```
Args:
image: A `Tensor` of shape [height, width, channels] with the image. May be uint8 or float32 with values in [0, 255].
seed (optional): A `Tensor` of shape [2] with the seed for the random number generator.
**augment_kwargs: Keyword arguments for the augmentation operations. The order of operations is determined by
the "augment_order" keyword argument. Other keyword arguments are passed to the corresponding augmentation
operation. See above for a list of operations.
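Example call (a minimal sketch using the `augment_kwargs` from the example config above; the seed,
if provided, is a length-2 int32 tensor as used by TensorFlow's stateless random ops):
    image = augment_image(image, seed=tf.constant([0, 42]), **augment_kwargs)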
"""
if "augment_order" not in augment_kwargs:
raise ValueError("augment_kwargs must contain an 'augment_order' key.")
# convert to float at the beginning to avoid each op converting back and
# forth between uint8 and float32 internally
orig_dtype = image.dtype
image = tf.image.convert_image_dtype(image, tf.float32)
if seed is None:
seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)
for op in augment_kwargs["augment_order"]:
seed = tf.random.stateless_uniform([2], seed, maxval=tf.dtypes.int32.max, dtype=tf.int32)
if op in augment_kwargs:
if hasattr(augment_kwargs[op], "items"):
image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op])
else:
image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op])
else:
image = AUGMENT_OPS[op](image, seed=seed)
# float images are expected to be in [0, 1]
image = tf.clip_by_value(image, 0, 1)
# convert back to original dtype and scale
image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True)
return image
import inspect
import string
from functools import partial
from typing import Any, Callable, Dict, Sequence, Union
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_datasets.core.dataset_builder import DatasetBuilder
from dlimp.utils import parallel_vmap
def _wrap(f, is_flattened):
"""Wraps a method to return a DLataset instead of a tf.data.Dataset."""
def wrapper(*args, **kwargs):
result = f(*args, **kwargs)
if not isinstance(result, DLataset) and isinstance(result, tf.data.Dataset):
# make the result a subclass of DLataset and the original class
result.__class__ = type(
"DLataset", (DLataset, type(result)), DLataset.__dict__.copy()
)
# propagate the is_flattened flag
if is_flattened is None:
result.is_flattened = f.__self__.is_flattened
else:
result.is_flattened = is_flattened
return result
return wrapper
class DLataset(tf.data.Dataset):
"""A DLimp Dataset. This is a thin wrapper around tf.data.Dataset that adds some utilities for working
with datasets of trajectories.
A DLataset starts out as a dataset of trajectories, where each dataset element is a single trajectory. A
dataset element is always a (possibly nested) dictionary from strings to tensors; however, a trajectory
has the additional property that each tensor has the same leading dimension, which is the trajectory
length. Each element of the trajectory is known as a frame.
A DLataset is just a tf.data.Dataset, so you can always use standard methods like `.map` and `.filter`.
However, a DLataset is also aware of the difference between trajectories and frames, so it provides some
additional methods. To perform a transformation at the trajectory level (e.g., restructuring, relabeling,
truncating), use `.traj_map`. To perform a transformation at the frame level (e.g., image decoding,
resizing, augmentations) use `.frame_map`.
Once there are no more trajectory-level transformations to perform, you can convert the DLataset to a
dataset of frames using `.flatten`. You can still use `.frame_map` after flattening, but using `.traj_map`
will raise an error.
"""
def __getattribute__(self, name):
# monkey-patches tf.data.Dataset methods to return DLatasets
attr = super().__getattribute__(name)
if inspect.ismethod(attr):
return _wrap(attr, None)
return attr
def _apply_options(self):
"""Applies some default options for performance."""
options = tf.data.Options()
options.autotune.enabled = True
options.deterministic = True
options.experimental_optimization.apply_default_optimizations = True
options.experimental_optimization.map_fusion = True
options.experimental_optimization.map_and_filter_fusion = True
options.experimental_optimization.inject_prefetch = False
options.experimental_warm_start = True
return self.with_options(options)
def with_ram_budget(self, gb: int) -> "DLataset":
"""Sets the RAM budget for the dataset. The default is half of the available memory.
Args:
gb (int): The RAM budget in GB.
"""
options = tf.data.Options()
options.autotune.ram_budget = gb * 1024 * 1024 * 1024 # GB --> Bytes
return self.with_options(options)
@staticmethod
def from_tfrecords(
dir_or_paths: Union[str, Sequence[str]],
shuffle: bool = True,
num_parallel_reads: int = tf.data.AUTOTUNE,
) -> "DLataset":
"""Creates a DLataset from tfrecord files. The type spec of the dataset is inferred from the first file. The
only constraint is that each example must be a trajectory where each entry is either a scalar, a tensor of shape
(1, ...), or a tensor of shape (T, ...), where T is the length of the trajectory.
Args:
dir_or_paths (Union[str, Sequence[str]]): Either a directory containing .tfrecord files, or a list of paths
to tfrecord files.
shuffle (bool, optional): Whether to shuffle the tfrecord files. Defaults to True.
num_parallel_reads (int, optional): The number of tfrecord files to read in parallel. Defaults to AUTOTUNE. This
can use an excessive amount of memory if reading from cloud storage; decrease if necessary.
"""
if isinstance(dir_or_paths, str):
paths = tf.io.gfile.glob(tf.io.gfile.join(dir_or_paths, "*.tfrecord"))
else:
paths = dir_or_paths
if len(paths) == 0:
raise ValueError(f"No tfrecord files found in {dir_or_paths}")
if shuffle:
paths = tf.random.shuffle(paths)
# extract the type spec from the first file
type_spec = _get_type_spec(paths[0])
# read the tfrecords (yields raw serialized examples)
dataset = _wrap(tf.data.TFRecordDataset, False)(
paths,
num_parallel_reads=num_parallel_reads,
)._apply_options()
# decode the examples (yields trajectories)
dataset = dataset.traj_map(partial(_decode_example, type_spec=type_spec))
# broadcast traj metadata, as well as add some extra metadata (_len, _traj_index, _frame_index)
dataset = dataset.enumerate().traj_map(_broadcast_metadata)
return dataset
@staticmethod
def from_rlds(
builder: DatasetBuilder,
split: str = "train",
shuffle: bool = True,
num_parallel_reads: int = tf.data.AUTOTUNE,
) -> "DLataset":
"""Creates a DLataset from the RLDS format (which is a special case of the TFDS format).
Args:
builder (DatasetBuilder): The TFDS dataset builder to load the dataset from.
split (str, optional): The split to load, specified in TFDS format. Defaults to "train".
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
num_parallel_reads (int, optional): The number of tfrecord files to read in parallel. Defaults to AUTOTUNE. This
can use an excessive amount of memory if reading from cloud storage; decrease if necessary.
"""
dataset = _wrap(builder.as_dataset, False)(
split=split,
shuffle_files=shuffle,
decoders={"steps": tfds.decode.SkipDecoding()},
read_config=tfds.ReadConfig(
skip_prefetch=True,
num_parallel_calls_for_interleave_files=num_parallel_reads,
interleave_cycle_length=num_parallel_reads,
),
)._apply_options()
dataset = dataset.enumerate().traj_map(_broadcast_metadata_rlds)
return dataset
def map(
self,
fn: Callable[[Dict[str, Any]], Dict[str, Any]],
num_parallel_calls=tf.data.AUTOTUNE,
**kwargs,
) -> "DLataset":
return super().map(fn, num_parallel_calls=num_parallel_calls, **kwargs)
def traj_map(
self,
fn: Callable[[Dict[str, Any]], Dict[str, Any]],
num_parallel_calls=tf.data.AUTOTUNE,
**kwargs,
) -> "DLataset":
"""Maps a function over the trajectories of the dataset. The function should take a single trajectory
as input and return a single trajectory as output.
"""
if self.is_flattened:
raise ValueError("Cannot call traj_map on a flattened dataset.")
return super().map(fn, num_parallel_calls=num_parallel_calls, **kwargs)
def frame_map(
self,
fn: Callable[[Dict[str, Any]], Dict[str, Any]],
num_parallel_calls=tf.data.AUTOTUNE,
**kwargs,
) -> "DLataset":
"""Maps a function over the frames of the dataset. The function should take a single frame as input
and return a single frame as output.
"""
if self.is_flattened:
return super().map(fn, num_parallel_calls=num_parallel_calls, **kwargs)
else:
return super().map(
parallel_vmap(fn, num_parallel_calls=num_parallel_calls),
num_parallel_calls=num_parallel_calls,
**kwargs,
)
def flatten(self, *, num_parallel_calls=tf.data.AUTOTUNE) -> "DLataset":
"""Flattens the dataset of trajectories into a dataset of frames."""
if self.is_flattened:
raise ValueError("Dataset is already flattened.")
dataset = self.interleave(
lambda traj: tf.data.Dataset.from_tensor_slices(traj),
cycle_length=num_parallel_calls,
num_parallel_calls=num_parallel_calls,
)
dataset.is_flattened = True
return dataset
def iterator(self, *, prefetch=tf.data.AUTOTUNE):
if prefetch == 0:
return self.as_numpy_iterator()
return self.prefetch(prefetch).as_numpy_iterator()
@staticmethod
def choose_from_datasets(datasets, choice_dataset, stop_on_empty_dataset=True):
if not isinstance(datasets[0], DLataset):
raise ValueError("Please pass DLatasets to choose_from_datasets.")
return _wrap(tf.data.Dataset.choose_from_datasets, datasets[0].is_flattened)(
datasets, choice_dataset, stop_on_empty_dataset=stop_on_empty_dataset
)
@staticmethod
def sample_from_datasets(
datasets,
weights=None,
seed=None,
stop_on_empty_dataset=False,
rerandomize_each_iteration=None,
):
if not isinstance(datasets[0], DLataset):
raise ValueError("Please pass DLatasets to sample_from_datasets.")
return _wrap(tf.data.Dataset.sample_from_datasets, datasets[0].is_flattened)(
datasets,
weights=weights,
seed=seed,
stop_on_empty_dataset=stop_on_empty_dataset,
rerandomize_each_iteration=rerandomize_each_iteration,
)
@staticmethod
def zip(*args, datasets=None, name=None):
if datasets is not None:
raise ValueError("Please do not pass `datasets=` to zip.")
if not isinstance(args[0], DLataset):
raise ValueError("Please pass DLatasets to zip.")
return _wrap(tf.data.Dataset.zip, args[0].is_flattened)(*args, name=name)
def _decode_example(
example_proto: tf.Tensor, type_spec: Dict[str, tf.TensorSpec]
) -> Dict[str, tf.Tensor]:
features = {key: tf.io.FixedLenFeature([], tf.string) for key in type_spec.keys()}
parsed_features = tf.io.parse_single_example(example_proto, features)
parsed_tensors = {
key: tf.io.parse_tensor(parsed_features[key], spec.dtype)
if spec is not None
else parsed_features[key]
for key, spec in type_spec.items()
}
for key in parsed_tensors:
if type_spec[key] is not None:
parsed_tensors[key] = tf.ensure_shape(
parsed_tensors[key], type_spec[key].shape
)
return parsed_tensors
def _get_type_spec(path: str) -> Dict[str, tf.TensorSpec]:
"""Get a type spec from a tfrecord file.
Args:
path (str): Path to a single tfrecord file.
Returns:
dict: A dictionary mapping feature names to tf.TensorSpecs.
"""
data = next(iter(tf.data.TFRecordDataset(path))).numpy()
example = tf.train.Example()
example.ParseFromString(data)
printable_chars = set(bytes(string.printable, "utf-8"))
out = {}
for key, value in example.features.feature.items():
data = value.bytes_list.value[0]
# stupid hack to deal with strings that are not encoded as tensors
if all(char in printable_chars for char in data):
out[key] = None
continue
tensor_proto = tf.make_tensor_proto([])
tensor_proto.ParseFromString(data)
dtype = tf.dtypes.as_dtype(tensor_proto.dtype)
shape = [d.size for d in tensor_proto.tensor_shape.dim]
if shape:
shape[0] = None # first dimension is trajectory length, which is variable
out[key] = tf.TensorSpec(shape=shape, dtype=dtype)
return out
def _broadcast_metadata(
i: tf.Tensor, traj: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
"""
Each element of a dlimp dataset is a trajectory. This means each entry must either have a leading dimension equal to
the length of the trajectory, have a leading dimension of 1, or be a scalar. Entries with a leading dimension of 1
and scalars are assumed to be trajectory-level metadata. This function broadcasts these entries to the length of the
trajectory, as well as adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
"""
# get the length of each dict entry
traj_lens = {
k: tf.shape(v)[0] if len(v.shape) > 0 else None for k, v in traj.items()
}
# take the maximum length as the canonical length (elements should either be the same length or length 1)
traj_len = tf.reduce_max([l for l in traj_lens.values() if l is not None])
for k in traj:
# broadcast scalars to the length of the trajectory
if traj_lens[k] is None:
traj[k] = tf.repeat(traj[k], traj_len)
traj_lens[k] = traj_len
# broadcast length-1 elements to the length of the trajectory
if traj_lens[k] == 1:
traj[k] = tf.repeat(traj[k], traj_len, axis=0)
traj_lens[k] = traj_len
asserts = [
# make sure all the lengths are the same
tf.assert_equal(
tf.size(tf.unique(tf.stack(list(traj_lens.values()))).y),
1,
message="All elements must have the same length.",
),
]
assert "_len" not in traj
assert "_traj_index" not in traj
assert "_frame_index" not in traj
traj["_len"] = tf.repeat(traj_len, traj_len)
traj["_traj_index"] = tf.repeat(i, traj_len)
traj["_frame_index"] = tf.range(traj_len)
with tf.control_dependencies(asserts):
return traj
def _broadcast_metadata_rlds(i: tf.Tensor, traj: Dict[str, Any]) -> Dict[str, Any]:
"""
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
"""
steps = traj.pop("steps")
traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
# broadcast metadata to the length of the trajectory
metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
# put steps back in
assert "traj_metadata" not in steps
traj = {**steps, "traj_metadata": metadata}
assert "_len" not in traj
assert "_traj_index" not in traj
assert "_frame_index" not in traj
traj["_len"] = tf.repeat(traj_len, traj_len)
traj["_traj_index"] = tf.repeat(i, traj_len)
traj["_frame_index"] = tf.range(traj_len)
return traj
from .common import *
from .frame_transforms import *
from .traj_transforms import *
from . import goal_relabeling
import fnmatch
from typing import Any, Callable, Dict, Union
def selective_tree_map(
x: Dict[str, Any],
match: Union[str, Callable[[str, Any], bool]],
map_fn: Callable,
*,
_keypath: str = "",
) -> Dict[str, Any]:
"""Maps a function over a nested dictionary, only applying it leaves that match a criterion.
If `match` is a string, it follows glob-style syntax. For example, "bar" will only match
a top-level key called "bar", "*bar" will match any leaf whose key ends with "bar",
and "*bar*" will match any subtree with a key that contains "bar".
Key paths are separated by "/". For example, "foo/bar" will match a leaf with key "bar" that
is nested under a key "foo".
Args:
x (Dict[str, Any]): The (possibly nested) dictionary to map over.
match (str or Callable[[str, Any], bool]): If a string, `map_fn` will only be applied to leaves
whose key path matches `match` using glob-style syntax. If a function, `map_fn` will only be
applied to leaves for which `match(key_path, value)` returns True.
map_fn (Callable): The function to apply.
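Example (illustrative; `tf.io.decode_jpeg` stands in for any per-leaf function):
    x = {"obs": {"image": raw_bytes, "state": state}, "lang": text}
    out = selective_tree_map(x, match="*image", map_fn=tf.io.decode_jpeg)
    # only the leaf at key path "obs/image" is decoded; all other leaves are returned unchanged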
"""
if not callable(match):
match_fn = lambda keypath, value: fnmatch.fnmatch(keypath, match)
else:
match_fn = match
out = {}
for key in x:
if isinstance(x[key], dict):
out[key] = selective_tree_map(
x[key], match_fn, map_fn, _keypath=_keypath + key + "/"
)
elif match_fn(_keypath + key, x[key]):
out[key] = map_fn(x[key])
else:
out[key] = x[key]
return out
def flatten_dict(d: Dict[str, Any], sep="/") -> Dict[str, Any]:
"""Given a nested dictionary, flatten it by concatenating keys with sep."""
flattened = {}
for k, v in d.items():
if isinstance(v, dict):
for k2, v2 in flatten_dict(v, sep=sep).items():
flattened[k + sep + k2] = v2
else:
flattened[k] = v
return flattened
def unflatten_dict(d: Dict[str, Any], sep="/") -> Dict[str, Any]:
"""Given a flattened dictionary, unflatten it by splitting keys by sep."""
unflattened = {}
for k, v in d.items():
keys = k.split(sep)
if len(keys) == 1:
unflattened[k] = v
else:
if keys[0] not in unflattened:
unflattened[keys[0]] = {}
unflattened[keys[0]][sep.join(keys[1:])] = v
return unflattened
from functools import partial
from typing import Any, Callable, Dict, Sequence, Tuple, Union
import tensorflow as tf
from dlimp.augmentations import augment_image
from dlimp.utils import resize_depth_image, resize_image
from .common import selective_tree_map
def decode_images(
x: Dict[str, Any], match: Union[str, Sequence[str]] = "image"
) -> Dict[str, Any]:
"""Can operate on nested dicts. Decodes any leaves that have `match` anywhere in their path."""
if isinstance(match, str):
match = [match]
return selective_tree_map(
x,
lambda keypath, value: any([s in keypath for s in match])
and value.dtype == tf.string,
partial(tf.io.decode_image, expand_animations=False),
)
def resize_images(
x: Dict[str, Any],
match: Union[str, Sequence[str]] = "image",
size: Tuple[int, int] = (128, 128),
) -> Dict[str, Any]:
"""Can operate on nested dicts. Resizes any leaves that have `match` anywhere in their path. Takes uint8 images
as input and returns float images (still in [0, 255]).
"""
if isinstance(match, str):
match = [match]
return selective_tree_map(
x,
lambda keypath, value: any([s in keypath for s in match])
and value.dtype == tf.uint8,
partial(resize_image, size=size),
)
def resize_depth_images(
x: Dict[str, Any],
match: Union[str, Sequence[str]] = "depth",
size: Tuple[int, int] = (128, 128),
) -> Dict[str, Any]:
"""Can operate on nested dicts. Resizes any leaves that have `match` anywhere in their path. Takes float32 images
as input and returns float images (in arbitrary range).
"""
if isinstance(match, str):
match = [match]
return selective_tree_map(
x,
lambda keypath, value: any([s in keypath for s in match])
and value.dtype == tf.float32,
partial(resize_depth_image, size=size),
)
def augment(
x: Dict[str, Any],
match: Union[str, Callable[[str, Any], bool]] = "*image",
traj_identical: bool = True,
keys_identical: bool = True,
augment_kwargs: dict = {},
) -> Dict[str, Any]:
"""
Augments the input dictionary `x` by applying image augmentation to all values whose keypath contains `match`.
Args:
x (Dict[str, Any]): The input dictionary to augment.
match (str or Callable[[str, Any], bool]): See documentation for `selective_tree_map`.
Defaults to "*image", which matches all leaves whose key ends in "image".
traj_identical (bool, optional): Whether to use the same random seed for all images in a trajectory.
keys_identical (bool, optional): Whether to use the same random seed for all keys that are augmented.
augment_kwargs (dict, optional): Additional keyword arguments to pass to the `augment_image` function.
"""
toplevel_seed = tf.random.uniform([2], 0, 2**31 - 1, dtype=tf.int32)
def map_fn(value):
if keys_identical and traj_identical:
seed = [x["_traj_index"], x["_traj_index"]]
elif keys_identical and not traj_identical:
seed = toplevel_seed
elif not keys_identical and traj_identical:
raise NotImplementedError()
else:
seed = None
return augment_image(value, seed=seed, **augment_kwargs)
return selective_tree_map(
x,
match,
map_fn,
)
"""
Contains goal relabeling and reward logic written in TensorFlow.
Each relabeling function takes a trajectory with keys `obs` and `next_obs`. It returns a new trajectory with the added
keys `goals` and `rewards`. Keep in mind that `obs` and `next_obs` may themselves be dictionaries, and `goals` must
match their structure.
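For example (shapes illustrative), a trajectory of the form
    {"obs": {"image": (T, H, W, 3)}, "next_obs": {"image": (T, H, W, 3)}}
comes back with the additional keys
    "goals": {"image": (T, H, W, 3)} and "rewards": (T,)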
"""
from typing import Any, Dict
import tensorflow as tf
def uniform(traj: Dict[str, Any], reached_proportion: float):
"""Relabels with a true uniform distribution over future states. With probability reached_proportion,
obs[i] gets a goal equal to next_obs[i]. In this case, the reward is 0. Otherwise,
obs[i] gets a goal sampled uniformly from the set next_obs[i + 1:], and the reward is -1.
"""
traj_len = tf.shape(tf.nest.flatten(traj)[0])[0]
# select a random future index for each transition i in the range [i + 1, traj_len)
rand = tf.random.uniform([traj_len])
low = tf.cast(tf.range(traj_len) + 1, tf.float32)
high = tf.cast(traj_len, tf.float32)
goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
# TODO(kvablack): don't know how I got an out-of-bounds during training,
# could not reproduce, trying to patch it for now
goal_idxs = tf.minimum(goal_idxs, traj_len - 1)
# select a random proportion of transitions to relabel with the next obs
goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion
# the last transition must be goal-reaching
goal_reached_mask = tf.logical_or(
goal_reached_mask, tf.range(traj_len) == traj_len - 1
)
# make goal-reaching transitions have an offset of 0
goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs)
# select goals
traj["goals"] = tf.nest.map_structure(
lambda x: tf.gather(x, goal_idxs),
traj["next_obs"],
)
# reward is 0 for goal-reaching transitions, -1 otherwise
traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32)
return traj
def last_state_upweighted(traj: Dict[str, Any], reached_proportion: float):
"""
A weird relabeling scheme where the last state gets upweighted. For each transition i, a uniform random number is
generated in the range [i + 1, i + traj_len). It then gets clipped to be less than traj_len. Therefore, the first
transition (i = 0) gets a goal sampled uniformly from the future, but for i > 0 the last state gets more and more
upweighted.
"""
traj_len = tf.shape(tf.nest.flatten(traj)[0])[0]
# select a random future index for each transition
offsets = tf.random.uniform(
[traj_len],
minval=1,
maxval=traj_len,
dtype=tf.int32,
)
# select random transitions to relabel as goal-reaching
goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion
# last transition is always goal-reaching
goal_reached_mask = tf.logical_or(
goal_reached_mask, tf.range(traj_len) == traj_len - 1
)
# the goal will come from the current transition if the goal was reached
offsets = tf.where(goal_reached_mask, 0, offsets)
# convert from relative to absolute indices
indices = tf.range(traj_len) + offsets
# clamp out of bounds indices to the last transition
indices = tf.minimum(indices, traj_len - 1)
# select goals
traj["goals"] = tf.nest.map_structure(
lambda x: tf.gather(x, indices),
traj["next_obs"],
)
# reward is 0 for goal-reaching transitions, -1 otherwise
traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32)
return traj
def geometric(traj: Dict[str, Any], reached_proportion: float, discount: float):
"""
Relabels with a geometric distribution over future states. With probability reached_proportion, obs[i] gets
a goal equal to next_obs[i]. In this case, the reward is 0. Otherwise, obs[i] gets a goal sampled
geometrically from the set next_obs[i + 1:], and the reward is -1.
"""
traj_len = tf.shape(tf.nest.flatten(traj)[0])[0]
# geometrically select a future index for each transition i in the range [i + 1, traj_len)
arange = tf.range(traj_len)
is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32)
d = discount ** tf.cast(arange[None] - arange[:, None], tf.float32)
probs = is_future_mask * d
# The indexing changes the shape from [seq_len, 1] to [seq_len]
goal_idxs = tf.random.categorical(
logits=tf.math.log(probs), num_samples=1, dtype=tf.int32
)[:, 0]
# select a random proportion of transitions to relabel with the next obs
goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion
# the last transition must be goal-reaching
goal_reached_mask = tf.logical_or(
goal_reached_mask, tf.range(traj_len) == traj_len - 1
)
# make goal-reaching transitions have an offset of 0
goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs)
# select goals
traj["goals"] = tf.nest.map_structure(
lambda x: tf.gather(x, goal_idxs),
traj["next_obs"],
)
# reward is 0 for goal-reaching transitions, -1 otherwise
traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32)
return traj
from typing import Any, Dict
import tensorflow as tf
def add_next_obs(traj: Dict[str, Any], pad: bool = True) -> Dict[str, Any]:
"""
Given a trajectory with a key "observations", add the key "next_observations". If pad is False, discards the last
value of all other keys. Otherwise, the last transition will have "observations" == "next_observations".
"""
if not pad:
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["next_observations"] = tf.nest.map_structure(
lambda x: x[1:], traj["observations"]
)
return traj_truncated
else:
traj["next_observations"] = tf.nest.map_structure(
lambda x: tf.concat((x[1:], x[-1:]), axis=0), traj["observations"]
)
return traj
from typing import Callable, Tuple
import tensorflow as tf
def tensor_feature(value):
return tf.train.Feature(
bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()])
)
def resize_image(image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor:
"""Resizes an image using Lanczos3 interpolation. Expects & returns uint8."""
assert image.dtype == tf.uint8
image = tf.image.resize(image, size, method="lanczos3", antialias=True)
image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8)
return image
def resize_depth_image(depth_image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor:
"""Resizes a depth image using bilinear interpolation. Expects & returns float32 in arbitrary range."""
assert depth_image.dtype == tf.float32
if len(depth_image.shape) < 3:
depth_image = tf.image.resize(
depth_image[..., None], size, method="bilinear", antialias=True
)[..., 0]
else:
depth_image = tf.image.resize(
depth_image, size, method="bilinear", antialias=True
)
return depth_image
def read_resize_encode_image(path: str, size: Tuple[int, int]) -> tf.Tensor:
"""Reads, decodes, resizes, and then re-encodes an image."""
data = tf.io.read_file(path)
image = tf.image.decode_jpeg(data)
image = resize_image(image, size)
image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8)
return tf.io.encode_jpeg(image, quality=95)
def vmap(fn: Callable) -> Callable:
"""
Vmap a function over the first dimension of a tensor (or nested structure of tensors). This
version does NOT parallelize the function; however, it fuses the function calls in a way that
appears to be more performant than tf.map_fn or tf.vectorized_map (when falling back to
while_loop) for certain situations.
Requires the first dimension of the input to be statically known.
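Example (illustrative; the leading dimension of 8 is statically known):
    batched = vmap(lambda frame: {"pixels": frame["pixels"] + 1})(
        {"pixels": tf.zeros([8, 64, 64, 3])}
    )
    # batched["pixels"].shape == (8, 64, 64, 3)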
"""
def wrapped(structure):
return tf.nest.map_structure(
lambda *x: tf.stack(x),
*[
fn(tf.nest.pack_sequence_as(structure, x))
for x in zip(*map(tf.unstack, tf.nest.flatten(structure)))
],
)
return wrapped
def parallel_vmap(fn: Callable, num_parallel_calls=tf.data.AUTOTUNE) -> Callable:
"""
Vmap a function over the first dimension of a tensor (or nested structure of tensors). This
version attempts to parallelize the function using the tf.data API. I found this to be more
performant than tf.map_fn or tf.vectorized_map (when falling back to while_loop), but the batch
call appears to add significant overhead that may make it slower for some situations.
"""
def wrapped(structure):
return (
tf.data.Dataset.from_tensor_slices(structure)
.map(fn, deterministic=True, num_parallel_calls=num_parallel_calls)
.batch(
tf.cast(tf.shape(tf.nest.flatten(structure)[0])[0], tf.int64),
)
.get_single_element()
)
return wrapped
"""
Converts data from the BridgeData raw format to TFRecord format.
Consider the following directory structure for the input data:
bridgedata_raw/
rss/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
...
...
01/
...
The --depth parameter controls how much of the data to process at the
--input_path; for example, if --depth=5, then --input_path should be
"bridgedata_raw", and all data will be processed. If --depth=3, then
--input_path should be "bridgedata_raw/rss/toykitchen2", and only data
under "toykitchen2" will be processed.
Can write directly to Google Cloud Storage, but not read from it.
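Example invocation (the script path and output location are illustrative; the flags are defined below):
    python raw_to_tfrecord.py \
        --input_path bridgedata_raw --depth 5 \
        --output_path gs://my-bucket/bridgedata_tfrecords --num_workers 8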
"""
import glob
import logging
import os
import pickle
import random
from datetime import datetime
from functools import partial
from multiprocessing import Pool
import numpy as np
import tensorflow as tf
import tqdm
from absl import app, flags
from tqdm_multiprocess import TqdmMultiProcessPool
import dlimp as dl
from dlimp.utils import read_resize_encode_image, tensor_feature
FLAGS = flags.FLAGS
flags.DEFINE_string("input_path", None, "Input path", required=True)
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_integer(
"depth",
5,
"Number of directories deep to traverse to the dated directory. Looks for"
"{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...",
)
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
flags.DEFINE_float(
"train_proportion", 0.9, "Proportion of data to use for training (rather than val)"
)
flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
flags.DEFINE_integer("shard_size", 200, "Maximum number of trajectories per shard")
IMAGE_SIZE = (256, 256)
CAMERA_VIEWS = {"images0", "images1", "images2"}
def process_images(path): # processes images at a trajectory level
image_dirs = set(os.listdir(str(path))).intersection(CAMERA_VIEWS)
image_paths = [
sorted(
glob.glob(os.path.join(path, image_dir, "im_*.jpg")),
key=lambda x: int(x.split("_")[-1].split(".")[0]),
)
for image_dir in image_dirs
]
filenames = [[path.split("/")[-1] for path in x] for x in image_paths]
assert all(x == filenames[0] for x in filenames)
d = {
image_dir: [read_resize_encode_image(path, IMAGE_SIZE) for path in p]
for image_dir, p in zip(image_dirs, image_paths)
}
for missing in CAMERA_VIEWS - set(d.keys()):
d[missing] = [""] * len(
image_paths[0]
) # empty string is a placeholder for missing images
return d
def process_state(path):
fp = os.path.join(path, "obs_dict.pkl")
with open(fp, "rb") as f:
x = pickle.load(f)
return x["full_state"]
def process_actions(path):
fp = os.path.join(path, "policy_out.pkl")
with open(fp, "rb") as f:
act_list = pickle.load(f)
if isinstance(act_list[0], dict):
act_list = [x["actions"] for x in act_list]
return act_list
def process_lang(path):
fp = os.path.join(path, "lang.txt")
text = "" # empty string is a placeholder for missing text
if os.path.exists(fp):
with open(fp, "r") as f:
text = f.readline().strip()
return text
# create a tfrecord for a group of trajectories
def create_tfrecord(paths, output_path, tqdm_func, global_tqdm):
writer = tf.io.TFRecordWriter(output_path)
for path in paths:
try:
# Data collected prior to 2021-07-23 has a delay of 1, otherwise a delay of 0
date_time = datetime.strptime(path.split("/")[-4], "%Y-%m-%d_%H-%M-%S")
latency_shift = date_time < datetime(2021, 7, 23)
out = dict()
out["obs"] = process_images(path)
out["obs"]["state"] = process_state(path)
out["actions"] = process_actions(path)
out["lang"] = process_lang(path)
# shift the actions according to camera latency
if latency_shift:
out["obs"] = {k: v[1:] for k, v in out["obs"].items()}
out["actions"] = out["actions"][:-1]
# append a null action to the end
out["actions"].append(np.zeros_like(out["actions"][0]))
assert (
len(out["actions"])
== len(out["obs"]["state"])
== len(out["obs"]["images0"])
)
example = tf.train.Example(
features=tf.train.Features(
feature={
k: tensor_feature(v)
for k, v in dl.transforms.flatten_dict(out).items()
}
)
)
writer.write(example.SerializeToString())
except Exception as e:
import sys
import traceback
traceback.print_exc()
logging.error(f"Error processing {path}")
sys.exit(1)
global_tqdm.update(1)
writer.close()
global_tqdm.write(f"Finished {output_path}")
def get_traj_paths(path, train_proportion):
train_traj = []
val_traj = []
for dated_folder in os.listdir(path):
# a mystery left by the greats of the past
if "lmdb" in dated_folder:
continue
search_path = os.path.join(path, dated_folder, "raw", "traj_group*", "traj*")
all_traj = glob.glob(search_path)
if not all_traj:
logging.info(f"no trajs found in {search_path}")
continue
random.shuffle(all_traj)
train_traj += all_traj[: int(len(all_traj) * train_proportion)]
val_traj += all_traj[int(len(all_traj) * train_proportion) :]
return train_traj, val_traj
def main(_):
assert FLAGS.depth >= 1
if tf.io.gfile.exists(FLAGS.output_path):
if FLAGS.overwrite:
logging.info(f"Deleting {FLAGS.output_path}")
tf.io.gfile.rmtree(FLAGS.output_path)
else:
logging.info(f"{FLAGS.output_path} exists, exiting")
return
# each path is a directory that contains dated directories
paths = glob.glob(os.path.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1))))
# get trajectory paths in parallel
with Pool(FLAGS.num_workers) as p:
train_paths, val_paths = zip(
*p.map(
partial(get_traj_paths, train_proportion=FLAGS.train_proportion), paths
)
)
train_paths = [x for y in train_paths for x in y]
val_paths = [x for y in val_paths for x in y]
random.shuffle(train_paths)
random.shuffle(val_paths)
# shard paths
train_shards = np.array_split(
train_paths, np.ceil(len(train_paths) / FLAGS.shard_size)
)
val_shards = np.array_split(val_paths, np.ceil(len(val_paths) / FLAGS.shard_size))
# create output paths
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "train"))
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "val"))
train_output_paths = [
os.path.join(FLAGS.output_path, "train", f"{i}.tfrecord")
for i in range(len(train_shards))
]
val_output_paths = [
os.path.join(FLAGS.output_path, "val", f"{i}.tfrecord")
for i in range(len(val_shards))
]
# create tasks (see tqdm_multiprocess documentation)
tasks = [
(create_tfrecord, (train_shards[i], train_output_paths[i]))
for i in range(len(train_shards))
] + [
(create_tfrecord, (val_shards[i], val_output_paths[i]))
for i in range(len(val_shards))
]
# run tasks
pool = TqdmMultiProcessPool(FLAGS.num_workers)
with tqdm.tqdm(
total=len(train_paths) + len(val_paths),
dynamic_ncols=True,
position=0,
desc="Total progress",
) as pbar:
pool.map(pbar, tasks, lambda _: None, lambda _: None)
if __name__ == "__main__":
app.run(main)
"""
Converts data from a preprocessed Ego4D format to TFRecord format.
Expects a manifest.csv file with paths to directories containing JPEG files. Images should be 224x224.
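The manifest is expected to have the columns `index, directory, num_frames, text` (enforced by an
assertion in `main`); an illustrative row:
    0,/data/ego4d/clips/clip_000,120,"opens the drawer"
Frames inside each directory are expected to be named with zero-padded indices (000000.jpg, 000001.jpg, ...).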
"""
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tqdm
from absl import app, flags, logging
from tqdm_multiprocess import TqdmMultiProcessPool
from dlimp.utils import read_resize_encode_image, tensor_feature
FLAGS = flags.FLAGS
flags.DEFINE_string("input_path", None, "Input path", required=True)
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
flags.DEFINE_float(
"train_proportion", 0.9, "Proportion of data to use for training (rather than val)"
)
flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
flags.DEFINE_integer("shard_size", 200, "Maximum number of trajectories per shard")
IMAGE_SIZE = (224, 224)
# create a tfrecord for a group of trajectories
def create_tfrecord(manifest, output_path, tqdm_func, global_tqdm):
writer = tf.io.TFRecordWriter(output_path)
for _, row in manifest.iterrows():
# left-zero-pad the frame indices to length 6; this weird way of doing it is left over from when this had to be
# done in a tf graph; I'm leaving it in case I need it again someday
indices = tf.as_string(tf.range(row["num_frames"]))
indices = tf.strings.bytes_split(indices)
n = 6 - indices.row_lengths()
zeros = tf.fill([tf.reduce_sum(n)], "0")
zeros = tf.RaggedTensor.from_row_lengths(zeros, n)
padded = tf.concat([zeros, indices], axis=1)
padded = tf.strings.reduce_join(padded, axis=1)
# get the paths to all of the frames
paths = tf.strings.join([row["directory"], "/", padded, ".jpg"])
# read, resize, and re-encode the images
images = [read_resize_encode_image(path, IMAGE_SIZE) for path in paths]
example = tf.train.Example(
features=tf.train.Features(
feature={
"obs": tensor_feature(images),
"lang": tensor_feature(row["text"]),
}
)
)
writer.write(example.SerializeToString())
global_tqdm.update(1)
writer.close()
global_tqdm.write(f"Finished {output_path}")
def main(_):
if tf.io.gfile.exists(FLAGS.output_path):
if FLAGS.overwrite:
logging.info(f"Deleting {FLAGS.output_path}")
tf.io.gfile.rmtree(FLAGS.output_path)
else:
logging.info(f"{FLAGS.output_path} exists, exiting")
return
# get the manifest
manifest = pd.read_csv(os.path.join(FLAGS.input_path, "manifest.csv"))
assert list(manifest.columns) == ["index", "directory", "num_frames", "text"]
# get rid of the invalid path prefixes and replace them with the actual
# dataset path prefix
manifest["directory"] = manifest["directory"].apply(
lambda x: os.path.join(FLAGS.input_path, *x.strip("/").split("/")[-2:])
)
# train/val split
manifest = manifest.sample(frac=1.0, random_state=0)
train_manifest = manifest.iloc[: int(len(manifest) * FLAGS.train_proportion)]
val_manifest = manifest.iloc[int(len(manifest) * FLAGS.train_proportion) :]
# shard paths
train_shards = np.array_split(
train_manifest, np.ceil(len(train_manifest) / FLAGS.shard_size)
)
val_shards = np.array_split(
val_manifest, np.ceil(len(val_manifest) / FLAGS.shard_size)
)
# create output paths
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "train"))
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "val"))
train_output_paths = [
os.path.join(FLAGS.output_path, "train", f"{i}.tfrecord")
for i in range(len(train_shards))
]
val_output_paths = [
os.path.join(FLAGS.output_path, "val", f"{i}.tfrecord")
for i in range(len(val_shards))
]
# create tasks (see tqdm_multiprocess documentation)
tasks = [
(create_tfrecord, (train_shards[i], train_output_paths[i]))
for i in range(len(train_shards))
] + [
(create_tfrecord, (val_shards[i], val_output_paths[i]))
for i in range(len(val_shards))
]
# run tasks
pool = TqdmMultiProcessPool(FLAGS.num_workers)
with tqdm.tqdm(
total=len(manifest),
dynamic_ncols=True,
position=0,
desc="Total progress",
) as pbar:
pool.map(pbar, tasks, lambda _: None, lambda _: None)
if __name__ == "__main__":
app.run(main)
The Kinetics dataset is a set of short (10-second) YouTube clips with associated labels. It comes in 400-, 600-, and 700-class variants and is available at https://deepmind.com/research/open-source/kinetics.
For seamless integration with DLIMP, follow the instructions at https://github.com/cvdfoundation/kinetics-dataset to download the Kinetics variant of your choosing.
Then, to preprocess for DLIMP, run the following with `num` set to 400, 600, or 700:
```
python3 -m scripts.kinetics.raw_to_tfrecord --input_path /path-to-kinetics-dataset/k{num} --output_path /path-to-output/ --aspect_ratio True
```
"""
Converts data from the kinetics raw mp4 format to TFRecord format.
The assumptions of the data format are as follows:
k400/
annotations/
train.csv
val.csv
train/
id{time_start}_{time_end}.mp4
...
val/
id{time_start}_{time_end}.mp4
...
The --input_path should be the path to the k400 directory, and the --output_path
should be the path to the directory where the TFRecord files will be written.
--aspect_ratio controls whether the videos are first center cropped to 4:3 aspect before being resized to 240x240.
Follow instructions in https://github.com/cvdfoundation/kinetics-dataset for downloading the specific kinetics dataset of your choosing.
"""
import os
import imageio
import numpy as np
import pandas as pd
import tensorflow as tf
import tqdm
from absl import app, flags, logging
from tqdm_multiprocess import TqdmMultiProcessPool
from dlimp.utils import resize_image, tensor_feature
FLAGS = flags.FLAGS
flags.DEFINE_string("input_path", None, "Input path", required=True)
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
flags.DEFINE_bool("aspect_ratio", False, "Whether to preserve aspect ratio")
flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
flags.DEFINE_integer("shard_size", 400, "Maximum number of trajectories per shard")
# create a tfrecord for a group of trajectories
def create_tfrecord(shard, output_path, tqdm_func, global_tqdm):
writer = tf.io.TFRecordWriter(output_path)
for item in shard:
try:
video = imageio.mimread(item["path"], format="mp4", memtest=False)
except OSError:
# corrupted video
global_tqdm.update(1)
continue
if FLAGS.aspect_ratio:
# height and width both vary in this dataset
height = video[0].shape[0]
width = video[0].shape[1]
ratio = width / height
if ratio > 4 / 3:
# center crop horizontally
desired_width = int(np.round(4 / 3 * height))
video = [
image[
:, (width - desired_width) // 2 : (width + desired_width) // 2
]
for image in video
]
elif ratio < 4 / 3:
# center crop vertically
desired_height = int(np.round(3 / 4 * width))
video = [
image[
(height - desired_height) // 2 : (height + desired_height) // 2,
:,
]
for image in video
]
# now resize to square 240x240
video = [resize_image(image, (240, 240)) for image in video]
video = [
tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8)
for image in video
]
# encode
video = [tf.io.encode_jpeg(image, quality=95) for image in video]
example = tf.train.Example(
features=tf.train.Features(
feature={
"obs": tensor_feature(video),
"label": tensor_feature(item["label"]),
}
)
)
writer.write(example.SerializeToString())
global_tqdm.update(1)
writer.close()
global_tqdm.write(f"Finished {output_path}")
def main(_):
tf.config.set_visible_devices(
[], "GPU"
) # TF might look for GPUs and crash out if it finds one
if tf.io.gfile.exists(FLAGS.output_path):
if FLAGS.overwrite:
logging.info(f"Deleting {FLAGS.output_path}")
tf.io.gfile.rmtree(FLAGS.output_path)
else:
logging.info(f"{FLAGS.output_path} exists, exiting")
return
# load annotations
with open(os.path.join(FLAGS.input_path, "annotations", "train.csv"), "r") as f:
train_annotations = pd.read_csv(f)
with open(os.path.join(FLAGS.input_path, "annotations", "val.csv"), "r") as f:
val_annotations = pd.read_csv(f)
# filter
train_annotations = [
(row["label"], row["youtube_id"], row["time_start"], row["time_end"])
for idx, row in train_annotations.iterrows()
]
val_annotations = [
(row["label"], row["youtube_id"], row["time_start"], row["time_end"])
for idx, row in val_annotations.iterrows()
]
print(f"------ Train: {len(train_annotations)}, Val: {len(val_annotations)} ------")
train = []
count = 0
for label, youtube_id, time_start, time_end in train_annotations:
# Downloader was made by geniuses as you can tell
path = (
f"{FLAGS.input_path}/train/{youtube_id}_{time_start:06d}_{time_end:06d}.mp4"
)
if not os.path.exists(path):
count += 1
continue
train.append({"path": path, "label": label})
print("Number of train files not found: ", count)
val = []
count = 0
for label, youtube_id, time_start, time_end in val_annotations:
path = (
f"{FLAGS.input_path}/val/{youtube_id}_{time_start:06d}_{time_end:06d}.mp4"
)
if not os.path.exists(path):
count += 1
continue
val.append({"path": path, "label": label})
print("Number of val files not found: ", count)
# shard
train_shards = np.array_split(train, np.ceil(len(train) / FLAGS.shard_size))
val_shards = np.array_split(val, np.ceil(len(val) / FLAGS.shard_size))
# create output paths
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "train"))
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "val"))
train_output_paths = [
os.path.join(FLAGS.output_path, "train", f"{i}.tfrecord")
for i in range(len(train_shards))
]
val_output_paths = [
os.path.join(FLAGS.output_path, "val", f"{i}.tfrecord")
for i in range(len(val_shards))
]
# create tasks (see tqdm_multiprocess documentation)
tasks = [
(create_tfrecord, (train_shards[i], train_output_paths[i]))
for i in range(len(train_shards))
] + [
(create_tfrecord, (val_shards[i], val_output_paths[i]))
for i in range(len(val_shards))
]
# run tasks
pool = TqdmMultiProcessPool(FLAGS.num_workers)
with tqdm.tqdm(
total=len(train) + len(val),
dynamic_ncols=True,
position=0,
desc="Total progress",
) as pbar:
pool.map(pbar, tasks, lambda _: None, lambda _: None)
if __name__ == "__main__":
app.run(main)
{
"Attaching something to something": "1",
"Closing something": "5",
"Covering something with something": "6",
"Folding something": "14",
"Laying something on the table on its side, not upright": "21",
"Lifting something up completely without letting it drop down": "27",
"Lifting something up completely, then letting it drop down": "28",
"Lifting something with something on it": "29",
"Lifting up one end of something without letting it drop down": "30",
"Lifting up one end of something, then letting it drop down": "31",
"Moving something across a surface without it falling down": "35",
"Moving something and something away from each other": "36",
"Moving something and something closer to each other": "37",
"Moving something and something so they collide with each other": "38",
"Moving something and something so they pass each other": "39",
"Moving something away from something": "40",
"Moving something away from the camera": "41",
"Moving something closer to something": "42",
"Moving something down": "43",
"Moving something towards the camera": "44",
"Moving something up": "45",
"Opening something": "46",
"Picking something up": "47",
"Plugging something into something": "49",
"Pouring something into something": "59",
"Pouring something onto something": "61",
"Pouring something out of something": "62",
"Pulling something from behind of something": "85",
"Pulling something from left to right": "86",
"Pulling something from right to left": "87",
"Pulling something onto something": "88",
"Pulling something out of something": "89",
"Pushing something from left to right": "93",
"Pushing something from right to left": "94",
"Pushing something off of something": "95",
"Pushing something onto something": "96",
"Pushing something so that it slightly moves": "100",
"Pushing something with something": "101",
"Putting something and something on the table": "103",
"Putting something behind something": "104",
"Putting something in front of something": "105",
"Putting something into something": "106",
"Putting something next to something": "107",
"Putting something on a flat surface without letting it roll": "108",
"Putting something on a surface": "109",
"Putting something onto a slanted surface but it doesn't glide down": "111",
"Putting something onto something": "112",
"Putting something similar to other things that are already on the table": "114",
"Putting something that cannot actually stand upright upright on the table, so it falls on its side": "117",
"Putting something underneath something": "118",
"Putting something upright on the table": "119",
"Putting something, something and something on the table": "120",
"Removing something, revealing something behind": "121",
"Scooping something up with something": "123",
"Stacking number of something": "144",
"Taking one of many similar things on the table": "146",
"Taking something from somewhere": "147",
"Taking something out of something": "148",
"Turning something upside down": "164",
"Uncovering something": "171",
"Unfolding something": "172"
}
"""
Converts data from a raw somethingsomething format to TFRecord format.
"""
import json
import os
import imageio
import numpy as np
import tensorflow as tf
import tqdm
from absl import app, flags, logging
from tqdm_multiprocess import TqdmMultiProcessPool
from dlimp.utils import resize_image, tensor_feature
FLAGS = flags.FLAGS
flags.DEFINE_string("input_path", None, "Input path", required=True)
flags.DEFINE_string("label_path", None, "Labels to filter by", required=True)
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
flags.DEFINE_integer("shard_size", 200, "Maximum number of trajectories per shard")
# create a tfrecord for a group of trajectories
def create_tfrecord(shard, output_path, tqdm_func, global_tqdm):
writer = tf.io.TFRecordWriter(output_path)
for item in shard:
video = imageio.mimread(item["path"])
# center crop to 4:3 aspect ratio (same as bridge)
# height is always 240, width varies
width = video[0].shape[1]
if width > 320:
# center crop horizontally to 320
video = [
image[:, (width - 320) // 2 : (width + 320) // 2] for image in video
]
elif width < 320:
# center crop vertically
desired_height = int(np.round(3 / 4 * width))
video = [
image[(240 - desired_height) // 2 : (240 + desired_height) // 2, :]
for image in video
]
assert all(
np.isclose(image.shape[1] / image.shape[0], 4 / 3, atol=0.01)
for image in video
)
# now resize to square 240x240
video = [resize_image(image, (240, 240)) for image in video]
video = [
tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8)
for image in video
]
# encode
video = [tf.io.encode_jpeg(image, quality=95) for image in video]
example = tf.train.Example(
features=tf.train.Features(
feature={
"obs": tensor_feature(video),
"lang": tensor_feature(item["lang"]),
}
)
)
writer.write(example.SerializeToString())
global_tqdm.update(1)
writer.close()
global_tqdm.write(f"Finished {output_path}")
def main(_):
if tf.io.gfile.exists(FLAGS.output_path):
if FLAGS.overwrite:
logging.info(f"Deleting {FLAGS.output_path}")
tf.io.gfile.rmtree(FLAGS.output_path)
else:
logging.info(f"{FLAGS.output_path} exists, exiting")
return
# load annotations
with open(os.path.join(FLAGS.input_path, "annotations", "train.json"), "r") as f:
train_annotations = json.load(f)
with open(
os.path.join(FLAGS.input_path, "annotations", "validation.json"), "r"
) as f:
val_annotations = json.load(f)
# load labels to filter by
with open(FLAGS.label_path, "r") as f:
labels = set(json.load(f).keys())
# filter
train_annotations = [
x
for x in train_annotations
if x["template"].replace("[", "").replace("]", "") in labels
]
val_annotations = [
x
for x in val_annotations
if x["template"].replace("[", "").replace("]", "") in labels
]
print(f"------ Train: {len(train_annotations)}, Val: {len(val_annotations)} ------")
# get video paths
train = [
{
"path": os.path.join(
FLAGS.input_path, "20bn-something-something-v2", x["id"]
)
+ ".webm",
"lang": x["label"],
}
for x in train_annotations
]
val = [
{
"path": os.path.join(
FLAGS.input_path, "20bn-something-something-v2", x["id"]
)
+ ".webm",
"lang": x["label"],
}
for x in val_annotations
]
# shard
train_shards = np.array_split(train, np.ceil(len(train) / FLAGS.shard_size))
val_shards = np.array_split(val, np.ceil(len(val) / FLAGS.shard_size))
# create output paths
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "train"))
tf.io.gfile.makedirs(os.path.join(FLAGS.output_path, "val"))
train_output_paths = [
os.path.join(FLAGS.output_path, "train", f"{i}.tfrecord")
for i in range(len(train_shards))
]
val_output_paths = [
os.path.join(FLAGS.output_path, "val", f"{i}.tfrecord")
for i in range(len(val_shards))
]
# create tasks (see tqdm_multiprocess documentation)
tasks = [
(create_tfrecord, (train_shards[i], train_output_paths[i]))
for i in range(len(train_shards))
] + [
(create_tfrecord, (val_shards[i], val_output_paths[i]))
for i in range(len(val_shards))
]
# run tasks
pool = TqdmMultiProcessPool(FLAGS.num_workers)
with tqdm.tqdm(
total=len(train) + len(val),
dynamic_ncols=True,
position=0,
desc="Total progress",
) as pbar:
pool.map(pbar, tasks, lambda _: None, lambda _: None)
if __name__ == "__main__":
app.run(main)
Contains converters to the [RLDS format](https://github.com/google-research/rlds), a specification on top of the [TFDS](https://www.tensorflow.org/datasets) (TensorFlow Datasets) format, which is itself largely built on the TFRecord format. RLDS datasets can be loaded using `dlimp.DLataset.from_rlds`.
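For reference, loading a converted dataset might look like the sketch below. This is a minimal, non-authoritative example: the dataset directory and the `bridge_dataset/1.0.0` name are placeholders, and it assumes `from_rlds` accepts a TFDS builder plus a split name (check dlimp itself for the exact signature).

```python
# Minimal loading sketch; the path and dataset name are placeholders, and the
# from_rlds signature (builder + split) is an assumption.
import tensorflow_datasets as tfds
import dlimp

# Point TFDS at a dataset directory produced by one of these converters.
builder = tfds.builder_from_directory("/path/to/tensorflow_datasets/bridge_dataset/1.0.0")
ds = dlimp.DLataset.from_rlds(builder, split="train")

for traj in ds.take(1):
    print(traj.keys())
```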
Out of the box, TFDS only supports single-threaded dataset conversion and distributed dataset conversion using Apache Beam. `dataset_builder.py` contains a more middle-ground implementation that uses Python multiprocessing to parallelize conversion on a single machine. It is based heavily on Karl Pertsch's implementation (see [kpertsch/bridge_rlds_builder](https://github.com/kpertsch/bridge_rlds_builder/blob/f0d16c5a8384c1476aa1c274a9aef3a5f76cbada/bridge_dataset/conversion_utils.py)).
## Usage
Each subdirectory contains a specific dataset converter implementation that inherits from the `dataset_builder.MultiThreadedDatasetBuilder` class. First, install the multithreaded dataset builder by running `pip install .` in this directory. Individual converters may have additional requirements, which they specify in their own `requirements.txt`.
To build a particular dataset, `cd` into its corresponding directory and run `CUDA_VISIBLE_DEVICES="" tfds build --manual_dir <path_to_raw_data>`. See the individual dataset documentation for how to obtain the raw data. You may also want to modify settings inside the `<dataset_name>_dataset_builder.py` file (e.g., `NUM_WORKERS` and `CHUNKSIZE`).
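If you prefer to trigger a build from Python (e.g., for scripting or debugging) rather than the `tfds` CLI, the sketch below uses the standard TFDS `DownloadConfig` / `download_and_prepare` APIs. The module and class names (`something_something.something_something_dataset_builder`, `SomethingSomethingDataset`) and the paths are hypothetical placeholders for whichever converter you are building.

```python
# Hypothetical programmatic equivalent of
# `CUDA_VISIBLE_DEVICES="" tfds build --manual_dir <path_to_raw_data>`.
import tensorflow_datasets as tfds

# Placeholder import; substitute the builder class of the converter you are building.
from something_something.something_something_dataset_builder import SomethingSomethingDataset

builder = SomethingSomethingDataset()  # output defaults to ~/tensorflow_datasets
builder.download_and_prepare(
    download_config=tfds.download.DownloadConfig(
        manual_dir="/path/to/raw/data"  # same directory you would pass to --manual_dir
    )
)
```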