Commit 32e19929 authored by Roman Shapovalov, committed by Facebook GitHub Bot

SQL Index Dataset

Summary:
Moving the SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay.

It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS.

Reviewed By: bottler

Differential Revision: D45086611

fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
parent 7aeedd17
......@@ -450,6 +450,7 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
load_blobs: bool = True,
) -> FrameDataSubtype:
"""An abstract method to build the frame data based on raw frame/sequence
annotations, load the binary data and adjust them according to the metadata.
......@@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
Beware that modifications of frame data are done in-place.
Args:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
dataset_root: The root folder of the dataset; all paths in frame / sequence
annotations are defined w.r.t. this root. Has to be set if any of the
load_* flags below is true.
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
......@@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
path_manager: Optionally a PathManager for interpreting paths in a special way.
"""
dataset_root: str = ""
dataset_root: Optional[str] = None
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
......@@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
box_crop_context: float = 0.3
path_manager: Any = None
def __post_init__(self) -> None:
load_any_blob = (
self.load_images
or self.load_depths
or self.load_depth_masks
or self.load_masks
or self.load_point_clouds
)
if load_any_blob and self.dataset_root is None:
raise ValueError(
"dataset_root must be set to load any blob data. "
"Make sure it is set in either FrameDataBuilder or Dataset params."
)
if load_any_blob and not os.path.isdir(self.dataset_root): # pyre-ignore
raise ValueError(
f"dataset_root is passed but {self.dataset_root} does not exist."
)
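# Illustrative note (a sketch, not part of this diff): with the validation above,
# a metadata-only configuration is valid without a dataset root, e.g. the
# following hypothetical builder args:
#
#   no_blobs_kwargs = {
#       "dataset_root": None,
#       "load_images": False,
#       "load_depths": False,
#       "load_depth_masks": False,
#       "load_masks": False,
#       "load_point_clouds": False,
#   }
#
# whereas requesting any blob data without a valid dataset_root raises the
# ValueError above.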
def build(
self,
frame_annotation: types.FrameAnnotation,
......@@ -567,7 +588,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
......@@ -612,7 +633,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[np.ndarray, str]:
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
assert self.dataset_root is not None and entry.mask is not None
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = load_mask(self._local_path(full_path))
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
......@@ -647,7 +669,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
assert self.dataset_root is not None and entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
......@@ -657,6 +679,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
if self.load_depth_masks:
assert entry_depth.mask_path is not None
# pyre-ignore
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = load_depth_mask(self._local_path(mask_path))
else:
......@@ -705,6 +728,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
assert self.dataset_root is not None
return os.path.join(self.dataset_root, path)
def _local_path(self, path: str) -> str:
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This functionality requires SQLAlchemy 2.0 or later.
import math
import struct
from typing import Optional, Tuple
import numpy as np
from pytorch3d.implicitron.dataset.types import (
DepthAnnotation,
ImageAnnotation,
MaskAnnotation,
PointCloudAnnotation,
VideoAnnotation,
ViewpointAnnotation,
)
from sqlalchemy import LargeBinary
from sqlalchemy.orm import (
composite,
DeclarativeBase,
Mapped,
mapped_column,
MappedAsDataclass,
)
from sqlalchemy.types import TypeDecorator
# these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape):
class NumpyArrayType(TypeDecorator):
impl = LargeBinary
def process_bind_param(self, value, dialect):
if value is not None:
if value.shape != shape:
raise ValueError(f"Passed an array of wrong shape: {value.shape}")
return value.astype(np.float32).tobytes()
return None
def process_result_value(self, value, dialect):
if value is not None:
return np.frombuffer(value, dtype=np.float32).reshape(shape)
return None
return NumpyArrayType
def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
format_symbol = {
float: "f", # float32
int: "i", # int32
}[dtype]
class TupleType(TypeDecorator):
impl = LargeBinary
_format = format_symbol * math.prod(shape)
def process_bind_param(self, value, _):
if value is None:
return None
if len(shape) > 1:
value = np.array(value, dtype=dtype).reshape(-1)
return struct.pack(TupleType._format, *value)
def process_result_value(self, value, _):
if value is None:
return None
loaded = struct.unpack(TupleType._format, value)
if len(shape) > 1:
loaded = _rec_totuple(
np.array(loaded, dtype=dtype).reshape(shape).tolist()
)
return loaded
return TupleType
def _rec_totuple(t):
if isinstance(t, list):
return tuple(_rec_totuple(x) for x in t)
return t
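# Illustrative round-trip (a sketch, not part of this file): the factories can be
# exercised outside SQLAlchemy, e.g. for a 3x3 viewpoint rotation; the variable
# names below are hypothetical.
#
#   RotationType = TupleTypeFactory(float, shape=(3, 3))
#   blob = RotationType().process_bind_param(np.eye(3), None)
#   restored = RotationType().process_result_value(blob, None)
#   # `restored` is a nested 3x3 tuple of floats equal to the identity matrix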
class Base(MappedAsDataclass, DeclarativeBase):
"""subclasses will be converted to dataclasses"""
class SqlFrameAnnotation(Base):
__tablename__ = "frame_annots"
sequence_name: Mapped[str] = mapped_column(primary_key=True)
frame_number: Mapped[int] = mapped_column(primary_key=True)
frame_timestamp: Mapped[float] = mapped_column(index=True)
image: Mapped[ImageAnnotation] = composite(
mapped_column("_image_path"),
mapped_column("_image_size", TupleTypeFactory(int)),
)
depth: Mapped[DepthAnnotation] = composite(
mapped_column("_depth_path", nullable=True),
mapped_column("_depth_scale_adjustment", nullable=True),
mapped_column("_depth_mask_path", nullable=True),
)
mask: Mapped[MaskAnnotation] = composite(
mapped_column("_mask_path", nullable=True),
mapped_column("_mask_mass", index=True, nullable=True),
mapped_column(
"_mask_bounding_box_xywh",
TupleTypeFactory(float, shape=(4,)),
nullable=True,
),
)
viewpoint: Mapped[ViewpointAnnotation] = composite(
mapped_column(
"_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
),
mapped_column(
"_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
),
mapped_column(
"_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
),
mapped_column(
"_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
),
mapped_column("_viewpoint_intrinsics_format", nullable=True),
)
class SqlSequenceAnnotation(Base):
__tablename__ = "sequence_annots"
sequence_name: Mapped[str] = mapped_column(primary_key=True)
category: Mapped[str] = mapped_column(index=True)
video: Mapped[VideoAnnotation] = composite(
mapped_column("_video_path", nullable=True),
mapped_column("_video_length", nullable=True),
)
point_cloud: Mapped[PointCloudAnnotation] = composite(
mapped_column("_point_cloud_path", nullable=True),
mapped_column("_point_cloud_quality_score", nullable=True),
mapped_column("_point_cloud_n_points", nullable=True),
)
# the bigger the better
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
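# Illustrative usage (a sketch, not part of this file): reading frame annotations
# from a metadata SQLite file with SQLAlchemy 2.0; the file path and sequence name
# below are hypothetical.
#
#   from sqlalchemy import create_engine, select
#   from sqlalchemy.orm import Session
#
#   engine = create_engine("sqlite:///metadata.sqlite")
#   with Session(engine) as session:
#       stmt = select(SqlFrameAnnotation).where(
#           SqlFrameAnnotation.sequence_name == "<sequence_name_0>"
#       )
#       for frame in session.scalars(stmt):
#           print(frame.frame_number, frame.image.path)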
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import List, Optional, Tuple, Type
import numpy as np
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
)
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry,
run_auto_creation,
)
from .sql_dataset import SqlIndexDataset
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
# _NEED_CONTROL is a list of those elements of SqlIndexDataset which
# are not directly specified for it in the config but come from the
# DatasetMapProvider.
_NEED_CONTROL: Tuple[str, ...] = (
"path_manager",
"subsets",
"sqlite_metadata_file",
"subset_lists_file",
)
logger = logging.getLogger(__name__)
@registry.register
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"""
Generates the training, validation, and testing dataset objects for
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite database.
The dataset is organized in the filesystem as follows::
self.dataset_root
├── <possible/partition/0>
│ ├── <sequence_name_0>
│ │ ├── depth_masks
│ │ ├── depths
│ │ ├── images
│ │ ├── masks
│ │ └── pointcloud.ply
│ ├── <sequence_name_1>
│ │ ├── depth_masks
│ │ ├── depths
│ │ ├── images
│ │ ├── masks
│ │ └── pointcloud.ply
│ ├── ...
│ ├── <sequence_name_N>
│ ├── set_lists
│ ├── <subset_base_name_0>.json
│ ├── <subset_base_name_1>.json
│ ├── ...
│ ├── <subset_base_name_2>.json
│ ├── eval_batches
│ │ ├── <eval_batches_base_name_0>.json
│ │ ├── <eval_batches_base_name_1>.json
│ │ ├── ...
│ │ ├── <eval_batches_base_name_M>.json
│ ├── frame_annotations.jgz
│ ├── sequence_annotations.jgz
├── <possible/partition/1>
├── ...
├── <possible/partition/K>
├── set_lists
├── <subset_base_name_0>.sqlite
├── <subset_base_name_1>.sqlite
├── ...
├── <subset_base_name_2>.sqlite
├── eval_batches
│ ├── <eval_batches_base_name_0>.json
│ ├── <eval_batches_base_name_1>.json
│ ├── ...
│ ├── <eval_batches_base_name_M>.json
The dataset contains sequences named `<sequence_name_i>`, which may be partitioned by
directories such as `<possible/partition/0>`, e.g. representing categories, but they
can also be stored in a flat structure. Each sequence folder contains the sequence
images, depth maps, foreground masks, and valid-depth masks in the
`images`, `depths`, `masks`, and `depth_masks` directories, respectively. Furthermore,
`set_lists/` directories (per partition or global) store json or sqlite files
`<subset_base_name_l>.<ext>`, each describing a certain sequence subset.
These subset path conventions are not hard-coded; an arbitrary relative path can be
specified by setting `self.subset_lists_path` to the relative path w.r.t. the
dataset root.
Each `<subset_base_name_l>.json` file contains the following dictionary::
{
"train": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
"val": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
"test": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
}
defining the list of frames (identified with their `sequence_name` and
`frame_number`) in the "train", "val", and "test" subsets of the dataset. In the case
of the SQLite format, `<subset_base_name_l>.sqlite` contains a table with the header::
| sequence_name | frame_number | image_path | subset |
Note that `frame_number` can be obtained only from the metadata and
does not necessarily correspond to the numeric suffix of the corresponding image
file name (e.g. a file `<partition_0>/<sequence_name_0>/images/frame00005.jpg` can
have its frame number set to `20`, not 5).
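For illustration, a JSON subset list of the above form can be inspected
independently of this provider (a minimal sketch; the provider itself parses these
files internally)::
import json
with open("<subset_base_name_l>.json") as f:
subsets = json.load(f)
train_frames = {(seq, num) for seq, num, _path in subsets["train"]}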
Each `<eval_batches_base_name_M>.json` file contains a list of evaluation examples
in the following form::
[
[ # batch 1
(sequence_name: str, frame_number: int, image_path: str),
...
],
[ # batch 2
(sequence_name: str, frame_number: int, image_path: str),
...
],
]
Note that the evaluation examples always come from the `"test"` subset of the dataset
(test frames can repeat across batches). A batch can contain a single element,
which is typical in the case of regular radiance field fitting.
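For illustration, an eval batches file of the above form can be turned into a plain
list of (sequence_name, frame_number) batches (a minimal sketch; inside this provider
the parsed batches are attached to the test dataset and later consumed as a batch
sampler by the data loader provider)::
import json
with open("<eval_batches_base_name_0>.json") as f:
eval_batches = [[(seq, num) for seq, num, *_rest in batch] for batch in json.load(f)]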
Args:
subset_lists_path: The relative path to the dataset subset definition.
For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json".
By default (None), the dataset is not partitioned into subsets (in that case,
setting `ignore_subsets` will speed up construction).
dataset_root: The root folder of the dataset.
metadata_basename: name of the SQL metadata file in dataset_root;
not expected to be changed by users
test_on_train: Construct validation and test datasets from
the training subset; note that in practice, in this
case all subset dataset objects will be the same
only_test_set: Load only the test set. Incompatible with `test_on_train`.
ignore_subsets: Don’t filter by subsets in the dataset; note that in this
case all subset datasets will be the same
eval_batch_num_training_frames: Add a certain number of training frames to each
eval batch. Useful for evaluating models that require
source views as input (e.g. NeRF-WCE / PixelNeRF).
dataset_args: Specifies additional arguments to the
SqlIndexDataset constructor call.
path_manager_factory: (Optional) An object that generates an instance of
PathManager that can translate provided file paths.
path_manager_factory_class_type: The class type of `path_manager_factory`.
"""
category: Optional[str] = None
subset_list_name: Optional[str] = None # TODO: docs
# OR
subset_lists_path: Optional[str] = None
eval_batches_path: Optional[str] = None
dataset_root: str = _CO3D_SQL_DATASET_ROOT
metadata_basename: str = "metadata.sqlite"
test_on_train: bool = False
only_test_set: bool = False
ignore_subsets: bool = False
train_subsets: Tuple[str, ...] = ("train",)
val_subsets: Tuple[str, ...] = ("val",)
test_subsets: Tuple[str, ...] = ("test",)
eval_batch_num_training_frames: int = 0
# this is a mould that is never constructed, used to build self._dataset_map values
dataset_class_type: str = "SqlIndexDataset"
dataset: SqlIndexDataset
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def __post_init__(self):
super().__init__()
run_auto_creation(self)
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
if self.ignore_subsets and not self.only_test_set:
self.test_on_train = True # no point in loading same data 3 times
path_manager = self.path_manager_factory.get()
sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename)
sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file)
if not os.path.isfile(sqlite_metadata_file):
# The sqlite_metadata_file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError(
f"Looking for frame annotations in {sqlite_metadata_file}."
+ " Please specify a correct dataset_root folder."
+ " Note: By default the root folder is taken from the"
+ " CO3D_SQL_DATASET_ROOT environment variable."
)
if self.subset_lists_path and self.subset_list_name:
raise ValueError(
"subset_lists_path and subset_list_name cannot be both set"
)
subset_lists_file = self._get_lists_file("set_lists")
# setup the common dataset arguments
common_dataset_kwargs = {
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
"sqlite_metadata_file": sqlite_metadata_file,
"dataset_root": self.dataset_root,
"subset_lists_file": subset_lists_file,
"path_manager": path_manager,
}
if self.category:
logger.info(f"Forcing category filter in the datasets to {self.category}")
common_dataset_kwargs["pick_categories"] = self.category.split(",")
# get the used dataset type
dataset_type: Type[SqlIndexDataset] = registry.get(
SqlIndexDataset, self.dataset_class_type
)
expand_args_fields(dataset_type)
if subset_lists_file is not None and not os.path.isfile(subset_lists_file):
available_subsets = self._get_available_subsets(
OmegaConf.to_object(common_dataset_kwargs["pick_categories"])
)
msg = f"Cannot find subset list file {self.subset_lists_path}."
if available_subsets:
msg += f" Some of the available subsets: {str(available_subsets)}."
raise ValueError(msg)
train_dataset = None
val_dataset = None
if not self.only_test_set:
# load the training set
logger.debug("Constructing train dataset.")
train_dataset = dataset_type(
**common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets)
)
logger.info(f"Train dataset: {str(train_dataset)}")
if self.test_on_train:
assert train_dataset is not None
val_dataset = test_dataset = train_dataset
else:
# load the val and test sets
if not self.only_test_set:
# NOTE: this is always loaded in JsonProviderV2
logger.debug("Extracting val dataset.")
val_dataset = dataset_type(
**common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets)
)
logger.info(f"Val dataset: {str(val_dataset)}")
logger.debug("Extracting test dataset.")
eval_batches_file = self._get_lists_file("eval_batches")
del common_dataset_kwargs["eval_batches_file"]
test_dataset = dataset_type(
**common_dataset_kwargs,
subsets=self._get_subsets(self.test_subsets, True),
eval_batches_file=eval_batches_file,
)
logger.info(f"Test dataset: {str(test_dataset)}")
if (
eval_batches_file is not None
and self.eval_batch_num_training_frames > 0
):
self._extend_eval_batches(test_dataset)
self._dataset_map = DatasetMap(
train=train_dataset, val=val_dataset, test=test_dataset
)
def _get_subsets(self, subsets, is_eval: bool = False):
if self.ignore_subsets:
return None
if is_eval and self.eval_batch_num_training_frames > 0:
# we will need to have training frames for extended batches
return list(subsets) + list(self.train_subsets)
return subsets
def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None:
rng = np.random.default_rng(seed=0)
eval_batches = test_dataset.get_eval_batches()
if eval_batches is None:
raise ValueError("Eval batches were not loaded!")
for batch in eval_batches:
sequence = batch[0][0]
seq_frames = list(
test_dataset.sequence_frames_in_order(sequence, self.train_subsets)
)
idx_to_add = rng.permutation(len(seq_frames))[
: self.eval_batch_num_training_frames
]
batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add)
@classmethod
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
"""
Called by get_default_args.
Certain fields are not exposed on each dataset class
but rather are controlled by this provider class.
"""
for key in _NEED_CONTROL:
del args[key]
def create_dataset(self):
# No `dataset` member of this class is created.
# The dataset(s) live in `self.get_dataset_map`.
pass
def get_dataset_map(self) -> DatasetMap:
return self._dataset_map # pyre-ignore [16]
def _get_available_subsets(self, categories: List[str]):
"""
Get the available subset names for a given category folder (if given) inside
a root dataset folder `dataset_root`.
"""
path_manager = self.path_manager_factory.get()
subsets: List[str] = []
for prefix in [""] + categories:
set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists")
if not (
(path_manager is not None) and path_manager.isdir(set_list_dir)
) and not os.path.isdir(set_list_dir):
continue
set_list_files = (os.listdir if path_manager is None else path_manager.ls)(
set_list_dir
)
subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files)
return subsets
def _get_lists_file(self, flavor: str) -> Optional[str]:
if flavor == "eval_batches":
subset_lists_path = self.eval_batches_path
else:
subset_lists_path = self.subset_lists_path
if not subset_lists_path and not self.subset_list_name:
return None
category_elem = ""
if self.category and "," not in self.category:
# if multiple categories are given, looking for global set lists
category_elem = self.category
subset_lists_path = subset_lists_path or (
os.path.join(
category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}"
)
)
assert subset_lists_path
path_manager = self.path_manager_factory.get()
# try absolute path first
subset_lists_file = _get_local_path_check_extensions(
subset_lists_path, path_manager
)
if subset_lists_file:
return subset_lists_file
full_path = os.path.join(self.dataset_root, subset_lists_path)
subset_lists_file = _get_local_path_check_extensions(full_path, path_manager)
if not subset_lists_file:
raise FileNotFoundError(
f"Subset lists path given but not found: {full_path}"
)
return subset_lists_file
def _get_local_path_check_extensions(
path, path_manager, extensions=("", ".sqlite", ".json")
) -> Optional[str]:
for ext in extensions:
local = _local_path(path_manager, path + ext)
if os.path.isfile(local):
return local
return None
def _local_path(path_manager, path: str) -> str:
if path_manager is None:
return path
return path_manager.get_local_path(path)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DataLoaderMap,
SceneBatchSampler,
SequenceDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D
# and support both eval_batches protocols
@registry.register
class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider):
"""
Implementation of DataLoaderMapProviderBase that may use internal eval batches for
the test dataset. In particular, if `eval_batches_relpath` is set, it loads
eval batches from that json file, otherwise the test set is treated in the same way as
train and val, i.e. the parameters `dataset_length_test` and `test_conditioning_type`
are respected.
If conditioning is not required, then the batch size should
be set to 1, and most of the fields do not matter.
If conditioning is required, each batch will contain the main
frame to predict first, and the rest of the elements are used for
conditioning.
If images_per_seq_options is left empty, the conditioning
frames are picked according to the conditioning type given.
This does not take into account the order of frames in a
scene, or which frames belong to which scene.
If images_per_seq_options is given, then the conditioning types
must be SAME and the remaining fields are used.
Members:
batch_size: The size of the batch of the data loader.
num_workers: Number of data-loading threads in each data loader.
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
an epoch is the length of the training set.
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
an epoch is the length of the validation set.
dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of
batches in a testing epoch. Or 0 to mean an epoch is the length of the test
set.
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
value. Empty (the default) means that we are not careful about which frames
come from which scene.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestamps when available,
otherwise by frame numbers, finds the connected segments within the sequence
of sufficient length, then samples a random pivot element among them and
ideally uses it as a middle of the temporal window, shifting the borders
where necessary. This strategy mitigates the bias against shorter segments
and their boundaries.
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
difference in frame_number of neighbouring frames when forming connected
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
the whole sequence is considered a segment regardless of frame numbers.
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
"""
batch_size: int = 1
num_workers: int = 0
dataset_length_train: int = 0
dataset_length_val: int = 0
dataset_length_test: int = 0
images_per_seq_options: Tuple[int, ...] = ()
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1
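# Illustrative configuration (a sketch, not prescriptive): a few-view setup that
# samples several consecutive frames of the same sequence in each batch; the
# values below are hypothetical.
#
#   batch_size: 10
#   images_per_seq_options: (5,)
#   sample_consecutive_frames: True
#   consecutive_frames_max_gap: 4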
def __post_init__(self):
run_auto_creation(self)
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
"""
Returns a collection of data loaders for a given collection of datasets.
"""
train = self._make_generic_data_loader(
datasets.train,
self.dataset_length_train,
datasets.train,
)
val = self._make_generic_data_loader(
datasets.val,
self.dataset_length_val,
datasets.train,
)
if datasets.test is not None and datasets.test.get_eval_batches() is not None:
test = self._make_eval_data_loader(datasets.test)
else:
test = self._make_generic_data_loader(
datasets.test,
self.dataset_length_test,
datasets.train,
)
return DataLoaderMap(train=train, val=val, test=test)
def _make_eval_data_loader(
self,
dataset: Optional[DatasetBase],
) -> Optional[DataLoader[FrameData]]:
if dataset is None:
return None
return DataLoader(
dataset,
batch_sampler=dataset.get_eval_batches(),
**self._get_data_loader_common_kwargs(dataset),
)
def _make_generic_data_loader(
self,
dataset: Optional[DatasetBase],
num_batches: int,
train_dataset: Optional[DatasetBase],
) -> Optional[DataLoader[FrameData]]:
"""
Returns the dataloader for a dataset.
Args:
dataset: the dataset
num_batches: possible ceiling on number of batches per epoch
train_dataset: the training dataset, used if conditioning_type==TRAIN
"""
if dataset is None:
return None
data_loader_kwargs = self._get_data_loader_common_kwargs(dataset)
if len(self.images_per_seq_options) > 0:
# this is a typical few-view setup
# conditioning comes from the same subset since subsets are split by seqs
batch_sampler = SceneBatchSampler(
dataset,
self.batch_size,
num_batches=len(dataset) if num_batches <= 0 else num_batches,
images_per_seq_options=self.images_per_seq_options,
sample_consecutive_frames=self.sample_consecutive_frames,
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
)
return DataLoader(
dataset,
batch_sampler=batch_sampler,
**data_loader_kwargs,
)
if self.batch_size == 1:
# this is a typical many-view setup (without conditioning)
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
# edge case: conditioning on train subset, typical for Nerformer-like many-view
# there is only one sequence in all datasets, so we condition on another subset
return self._train_loader(
dataset, train_dataset, num_batches, data_loader_kwargs
)
def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]:
return {
"num_workers": self.num_workers,
"collate_fn": dataset.frame_data_type.collate,
}
......@@ -164,6 +164,7 @@ setup(
"tqdm>4.29.0",
"matplotlib",
"accelerate",
"sqlalchemy>=2.0",
],
},
entry_points={
......
{"train": [["cat0_seq0", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000000.jpg"], ["cat0_seq0", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000002.jpg"], ["cat0_seq0", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000004.jpg"], ["cat0_seq0", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000006.jpg"], ["cat0_seq0", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000008.jpg"], ["cat0_seq1", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000000.jpg"], ["cat0_seq1", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000002.jpg"], ["cat0_seq1", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000004.jpg"], ["cat0_seq1", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000006.jpg"], ["cat0_seq1", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000008.jpg"], ["cat0_seq2", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000000.jpg"], ["cat0_seq2", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000002.jpg"], ["cat0_seq2", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000004.jpg"], ["cat0_seq2", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000006.jpg"], ["cat0_seq2", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000008.jpg"], ["cat0_seq3", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000000.jpg"], ["cat0_seq3", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000002.jpg"], ["cat0_seq3", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000004.jpg"], ["cat0_seq3", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000006.jpg"], ["cat0_seq3", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000008.jpg"], ["cat0_seq4", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000000.jpg"], ["cat0_seq4", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000002.jpg"], ["cat0_seq4", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000004.jpg"], ["cat0_seq4", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000006.jpg"], ["cat0_seq4", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000008.jpg"], ["cat1_seq0", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000000.jpg"], ["cat1_seq0", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000002.jpg"], ["cat1_seq0", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000004.jpg"], ["cat1_seq0", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000006.jpg"], ["cat1_seq0", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000008.jpg"], ["cat1_seq1", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000000.jpg"], ["cat1_seq1", 2, 
"kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000002.jpg"], ["cat1_seq1", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000004.jpg"], ["cat1_seq1", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000006.jpg"], ["cat1_seq1", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000008.jpg"], ["cat1_seq2", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000000.jpg"], ["cat1_seq2", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000002.jpg"], ["cat1_seq2", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000004.jpg"], ["cat1_seq2", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000006.jpg"], ["cat1_seq2", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000008.jpg"], ["cat1_seq3", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000000.jpg"], ["cat1_seq3", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000002.jpg"], ["cat1_seq3", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000004.jpg"], ["cat1_seq3", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000006.jpg"], ["cat1_seq3", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000008.jpg"], ["cat1_seq4", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000000.jpg"], ["cat1_seq4", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000002.jpg"], ["cat1_seq4", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000004.jpg"], ["cat1_seq4", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000006.jpg"], ["cat1_seq4", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000008.jpg"]], "test": [["cat0_seq0", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000001.jpg"], ["cat0_seq0", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000003.jpg"], ["cat0_seq0", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000005.jpg"], ["cat0_seq0", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000007.jpg"], ["cat0_seq0", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000009.jpg"], ["cat0_seq1", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000001.jpg"], ["cat0_seq1", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000003.jpg"], ["cat0_seq1", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000005.jpg"], ["cat0_seq1", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000007.jpg"], ["cat0_seq1", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000009.jpg"], ["cat0_seq2", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000001.jpg"], ["cat0_seq2", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000003.jpg"], ["cat0_seq2", 5, 
"kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000005.jpg"], ["cat0_seq2", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000007.jpg"], ["cat0_seq2", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000009.jpg"], ["cat0_seq3", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000001.jpg"], ["cat0_seq3", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000003.jpg"], ["cat0_seq3", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000005.jpg"], ["cat0_seq3", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000007.jpg"], ["cat0_seq3", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000009.jpg"], ["cat0_seq4", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000001.jpg"], ["cat0_seq4", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000003.jpg"], ["cat0_seq4", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000005.jpg"], ["cat0_seq4", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000007.jpg"], ["cat0_seq4", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000009.jpg"], ["cat1_seq0", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000001.jpg"], ["cat1_seq0", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000003.jpg"], ["cat1_seq0", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000005.jpg"], ["cat1_seq0", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000007.jpg"], ["cat1_seq0", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000009.jpg"], ["cat1_seq1", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000001.jpg"], ["cat1_seq1", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000003.jpg"], ["cat1_seq1", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000005.jpg"], ["cat1_seq1", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000007.jpg"], ["cat1_seq1", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000009.jpg"], ["cat1_seq2", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000001.jpg"], ["cat1_seq2", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000003.jpg"], ["cat1_seq2", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000005.jpg"], ["cat1_seq2", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000007.jpg"], ["cat1_seq2", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000009.jpg"], ["cat1_seq3", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000001.jpg"], ["cat1_seq3", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000003.jpg"], ["cat1_seq3", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000005.jpg"], ["cat1_seq3", 7, 
"kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000007.jpg"], ["cat1_seq3", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000009.jpg"], ["cat1_seq4", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000001.jpg"], ["cat1_seq4", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000003.jpg"], ["cat1_seq4", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000005.jpg"], ["cat1_seq4", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000007.jpg"], ["cat1_seq4", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000009.jpg"]]}
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
SequenceDataLoaderMapProvider,
SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa
from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa
SqlIndexDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
TrainEvalDataLoaderMapProvider,
)
from pytorch3d.implicitron.tools.config import get_default_args
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
class TestCo3dSqlDataSource(unittest.TestCase):
def test_no_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.ignore_subsets = True
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.pick_categories = ["skateboard"]
dataset_args.limit_sequences_to = 1
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 202)
for frame in data_loaders.train:
self.assertIsNone(frame.frame_type)
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = (
"skateboard/set_lists/set_lists_manyview_dev_0.json"
)
# this will naturally limit to one sequence (no need to limit by cat/sequence)
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
for sampler_type in [
"SimpleDataLoaderMapProvider",
"SequenceDataLoaderMapProvider",
"TrainEvalDataLoaderMapProvider",
]:
args.data_loader_map_provider_class_type = sampler_type
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 100)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
# check loading blobs
self.assertEqual(frame.image_rgb.shape[-1], 800)
break
def test_sql_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
for sampler_type in [
"SimpleDataLoaderMapProvider",
"SequenceDataLoaderMapProvider",
"TrainEvalDataLoaderMapProvider",
]:
args.data_loader_map_provider_class_type = sampler_type
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 100)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(
frame.image_rgb.shape[-1], 800
) # check loading blobs
break
@unittest.skip("It takes 75 seconds; skipping by default")
def test_huge_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 3158974)
self.assertEqual(len(data_loaders.val), 518417)
self.assertEqual(len(data_loaders.test), 518417)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_broken_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "et_non_est"
provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
with self.assertRaises(FileNotFoundError) as err:
ImplicitronDataSource(**args)
# check the hint text
self.assertIn("Subset lists path given but not found", str(err.exception))
def test_eval_batches(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
provider_args.eval_batches_path = (
"skateboard/eval_batches/eval_batches_manyview_dev_0.json"
)
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 50)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_eval_batches_from_subset_list_name(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_list_name = "manyview_dev_0"
provider_args.category = "skateboard"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
data_source = ImplicitronDataSource(**args)
dataset, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 50)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_frame_access(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
frame_builder_args.load_point_clouds = True
frame_builder_args.box_crop = False # required for .meta
data_source = ImplicitronDataSource(**args)
dataset_map, _ = data_source.get_datasets_and_dataloaders()
dataset = dataset_map["train"]
for idx in [10, ("245_26182_52130", 22)]:
example_meta = dataset.meta[idx]
example = dataset[idx]
self.assertIsNone(example_meta.image_rgb)
self.assertIsNone(example_meta.fg_probability)
self.assertIsNone(example_meta.depth_map)
self.assertIsNone(example_meta.sequence_point_cloud)
self.assertIsNotNone(example_meta.camera)
self.assertIsNotNone(example.image_rgb)
self.assertIsNotNone(example.fg_probability)
self.assertIsNotNone(example.depth_map)
self.assertIsNotNone(example.sequence_point_cloud)
self.assertIsNotNone(example.camera)
self.assertEqual(example_meta.sequence_name, example.sequence_name)
self.assertEqual(example_meta.frame_number, example.frame_number)
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
self.assertEqual(example_meta.sequence_category, example.sequence_category)
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
torch.testing.assert_close(
example_meta.camera.focal_length, example.camera.focal_length
)
torch.testing.assert_close(
example_meta.camera.principal_point, example.camera.principal_point
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
from collections import Counter
import pkg_resources
import torch
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
NO_BLOBS_KWARGS = {
"dataset_root": "",
"load_images": False,
"load_depths": False,
"load_masks": False,
"load_depth_masks": False,
"box_crop": False,
}
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")
class TestSqlDataset(unittest.TestCase):
def test_basic(self, sequence="cat1_seq2", frame_number=4):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
# test indexing
with self.assertRaises(IndexError):
dataset[len(dataset) + 1]
# test sequence-frame indexing
item = dataset[sequence, frame_number]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, frame_number)
with self.assertRaises(IndexError):
dataset[sequence, 13]
def test_filter_empty_masks(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 78)
def test_pick_frames_sql_clause(self):
dataset_no_empty_masks = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# check the datasets are equal
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
for i in range(len(dataset)):
item_nem = dataset_no_empty_masks[i]
item = dataset[i]
self.assertEqual(item_nem.image_path, item.image_path)
# remove_empty_masks together with the custom criterion
dataset_ts = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_frames_sql_clause="frame_timestamp < 0.15",
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset_ts), 19)
def test_limit_categories(self, category="cat0"):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_categories=[category],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
for i in range(len(dataset)):
self.assertEqual(dataset[i].sequence_category, category)
def test_limit_sequences(self, num_sequences=3):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_sequences_to=num_sequences,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 10 * num_sequences)
def delist(sequence_name):
return sequence_name if isinstance(sequence_name, str) else sequence_name[0]
unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))}
self.assertEqual(len(unique_seqs), num_sequences)
def test_pick_exclude_sequencess(self, sequence="cat1_seq2"):
# pick sequence
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_sequences=[sequence],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 10)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertCountEqual(unique_seqs, {sequence})
item = dataset[sequence, 0]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, 0)
# exclude sequence
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
exclude_sequences=[sequence],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 90)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertNotIn(sequence, unique_seqs)
with self.assertRaises(IndexError):
dataset[sequence, 0]
def test_limit_frames(self, num_frames=13):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=num_frames,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), num_frames)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertEqual(len(unique_seqs), 2)
# test when the limit is not binding
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=1000,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
def test_limit_frames_per_sequence(self, num_frames=2):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
n_frames_per_sequence=num_frames,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), num_frames * 10)
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertEqual(len(seq_counts), 10)
self.assertCountEqual(
set(seq_counts.values()), {2}
) # all counts are num_frames
with self.assertRaises(IndexError):
dataset[next(iter(seq_counts)), num_frames + 1]
# test when the limit is not binding
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
n_frames_per_sequence=13,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
def test_filter_medley(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_categories=["cat1"],
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
limit_to=14, # retaining full "cat1_seq1" and 4 from "cat1_seq2"
n_frames_per_sequence=6, # cutting "cat1_seq1" to 6 frames
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
self.assertEqual(seq_counts["cat1_seq1"], 6)
self.assertEqual(seq_counts["cat1_seq2"], 4)
def test_subsets_trivial(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
limit_to=100, # force sorting
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
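
# remove_empty_masks combined with subset lists should drop frames whose mask
# is empty, leaving 78 of the 100 subset frames.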
def test_subsets_filter_empty_masks(self):
# this case needs its own test: it takes a different code path based on `df.drop()`
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 78)
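
# pick_frames_sql_clause is a raw SQL condition evaluated over the frame
# metadata; the mask-mass clause below should reproduce remove_empty_masks.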
def test_subsets_pick_frames_sql_clause(self):
dataset_no_empty_masks = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# check the datasets are equal
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
for i in range(len(dataset)):
item_nem = dataset_no_empty_masks[i]
item = dataset[i]
self.assertEqual(item_nem.image_path, item.image_path)
# remove_empty_masks together with the custom criterion
dataset_ts = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_frames_sql_clause="frame_timestamp < 0.15",
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset_ts), 19)
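
# A single subset: "train" holds 50 of the 100 frames, every other frame of
# each sequence (frame numbers 2 apart), hence the stride-2 checks below.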
def test_single_subset(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
with self.assertRaises(IndexError):
dataset[51]
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number < 2:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 2)
last_frame_number = item.frame_number
item = dataset[last_sequence, 0]
self.assertEqual(item.sequence_name, last_sequence)
with self.assertRaises(IndexError):
dataset[last_sequence, 1]
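
# The same filter combination as test_filter_medley, applied on top of the
# "train" subset (5 frames per sequence), so the surviving counts are smaller.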
def test_subset_with_filters(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
pick_categories=["cat1"],
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
limit_to=7, # retaining full train set of "cat1_seq1" and 2 from "cat1_seq2"
n_frames_per_sequence=3, # cutting "cat1_seq1" to 3 frames
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# result: preserved 3 frames from cat1_seq1 and 2 from cat1_seq2
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
self.assertEqual(seq_counts["cat1_seq1"], 3)
self.assertEqual(seq_counts["cat1_seq2"], 2)
def test_visitor(self):
dataset_sorted = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
sequences = dataset_sorted.sequence_names()
i = 0
for seq in sequences:
last_ts = float("-Inf")
for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
self.assertEqual(i, idx)
i += 1
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
# test legacy visitor
old_indices = None
for seq in sequences:
last_ts = float("-Inf")
rows = dataset_sorted._index.index.get_loc(seq)
indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
self.assertEqual(len(fn_ts_list), len(indices))
if old_indices:
# check that it raises when asked for indices spanning multiple sequences
with self.assertRaises(ValueError):
dataset_sorted.get_frame_numbers_and_timestamps(
indices + old_indices
)
old_indices = indices
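
# The visitor should respect subsets: each sequence yields all 10 frames
# without a subset filter and only its 5 "train" frames with one, both in
# timestamp order.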
def test_visitor_subsets(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=100, # force sorting
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
sequences = dataset.sequence_names()
i = 0
for seq in sequences:
last_ts = float("-Inf")
seq_frames = list(dataset.sequence_frames_in_order(seq))
self.assertEqual(len(seq_frames), 10)
for ts, _, idx in seq_frames:
self.assertEqual(i, idx)
i += 1
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
last_ts = float("-Inf")
train_frames = list(dataset.sequence_frames_in_order(seq, "train"))
self.assertEqual(len(train_frames), 5)
for ts, _, _ in train_frames:
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
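
# category_to_sequence_names maps categories to their sequences; the test DB
# has 2 categories with 5 sequences each.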
def test_category_to_sequence_names(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
cat_to_seqs = dataset.category_to_sequence_names()
self.assertEqual(len(cat_to_seqs), 2)
self.assertIn("cat1", cat_to_seqs)
self.assertEqual(len(cat_to_seqs["cat1"]), 5)
# check that override preserves the behavior
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
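
# The category-to-sequence map must reflect sequence filters: the excluded
# cat1_seq0 should disappear from cat1.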
def test_category_to_sequence_names_filters(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
exclude_sequences=["cat1_seq0"],
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
cat_to_seqs = dataset.category_to_sequence_names()
self.assertEqual(len(cat_to_seqs), 2)
self.assertIn("cat1", cat_to_seqs)
self.assertEqual(len(cat_to_seqs["cat1"]), 4) # minus one
# check that override preserves the behavior
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
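
# dataset.meta[...] gives metadata-only access and should agree with full
# __getitem__ on sequence/frame fields and camera parameters.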
def test_meta_access(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
for idx in [10, ("cat0_seq2", 2)]:
example_meta = dataset.meta[idx]
example = dataset[idx]
self.assertEqual(example_meta.sequence_name, example.sequence_name)
self.assertEqual(example_meta.frame_number, example.frame_number)
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
self.assertEqual(example_meta.sequence_category, example.sequence_category)
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
torch.testing.assert_close(
example_meta.camera.focal_length, example.camera.focal_length
)
torch.testing.assert_close(
example_meta.camera.principal_point, example.camera.principal_point
)
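
# The meta accessor should never load blobs even when dataset_root is set:
# image, mask, depth and point cloud stay None while the camera is populated.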
def test_meta_access_no_blobs(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args={
"dataset_root": ".",
"box_crop": False, # required by blob-less accessor
},
)
self.assertIsNone(dataset.meta[0].image_rgb)
self.assertIsNone(dataset.meta[0].fg_probability)
self.assertIsNone(dataset.meta[0].depth_map)
self.assertIsNone(dataset.meta[0].sequence_point_cloud)
self.assertIsNotNone(dataset.meta[0].camera)