Commit 0f12c516 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

data_loader_map_provider

Summary: replace dataloader_zoo with a pluggable DataLoaderMapProvider.

Reviewed By: shapovalov

Differential Revision: D36475441

fbshipit-source-id: d16abb190d876940434329928f2e3f2794a25416
parent 79c61a2d
@@ -237,7 +237,7 @@ generic_model_args: GenericModel
solver_args: init_optimizer
data_source_args: ImplicitronDataSource
└-- dataset_map_provider_*_args
└-- dataloader_args
└-- data_loader_map_provider_*_args
```
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
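For illustration, here is a minimal sketch of how the renamed block slots into the tree above, written as an OmegaConf fragment in Python; the override values are placeholders taken from the configs changed below, not something this commit introduces:

```python
from omegaconf import OmegaConf

# Hypothetical override fragment mirroring the hierarchy above: the provider
# implementation is picked by *_class_type and configured through the
# matching *_args block.
overrides = OmegaConf.create(
    {
        "data_source_args": {
            "data_loader_map_provider_class_type": "SequenceDataLoaderMapProvider",
            "data_loader_map_provider_SequenceDataLoaderMapProvider_args": {
                "batch_size": 10,
                "dataset_len": 1000,
            },
        }
    }
)
print(OmegaConf.to_yaml(overrides))
```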
......
@@ -7,7 +7,7 @@ visualize_interval: 0
visdom_port: 8097
data_source_args:
dataset_provider_class_type: JsonIndexDatasetMapProvider
dataloader_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
......
@@ -2,7 +2,7 @@ defaults:
- repro_base.yaml
- _self_
data_source_args:
dataloader_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
......
@@ -2,7 +2,7 @@ defaults:
- repro_base
- _self_
data_source_args:
dataloader_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
dataset_len: 1000
dataset_len_val: 1
......
@@ -2,7 +2,7 @@ defaults:
- repro_singleseq_base
- _self_
data_source_args:
dataloader_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
......
@@ -64,8 +64,8 @@ import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
@@ -553,7 +553,7 @@ def _eval_and_dump(
cfg,
task: Task,
datasets: DatasetMap,
dataloaders: Dataloaders,
dataloaders: DataLoaderMap,
model,
stats,
device,
@@ -566,7 +566,7 @@ def _eval_and_dump(
dataloader = dataloaders.test
if dataloader is None:
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
raise ValueError('DataLoaderMap has to contain the "test" entry for eval!')
if task == Task.SINGLE_SEQUENCE:
if datasets.train is None:
......
@@ -8,7 +8,7 @@ from dataclasses import dataclass
from typing import Optional, Sequence
import torch
from pytorch3d.implicitron.tools.config import enable_get_default_args
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_map_provider import DatasetMap
@@ -16,40 +16,51 @@ from .scene_batch_sampler import SceneBatchSampler
@dataclass
class Dataloaders:
class DataLoaderMap:
"""
A provider of dataloaders for implicitron.
A collection of data loaders for Implicitron.
Members:
train: a dataloader for training
val: a dataloader for validating during training
test: a dataloader for final evaluation
train: a data loader for training
val: a data loader for validating during training
test: a data loader for final evaluation
"""
train: Optional[torch.utils.data.DataLoader[FrameData]]
val: Optional[torch.utils.data.DataLoader[FrameData]]
test: Optional[torch.utils.data.DataLoader[FrameData]]
def __getitem__(
self, split: str
) -> Optional[torch.utils.data.DataLoader[FrameData]]:
"""
Get one of the data loaders by key (name of data split)
"""
if split not in ["train", "val", "test"]:
raise ValueError(f"{split} was not a valid split name (train/val/test)")
return getattr(self, split)
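A usage sketch for the new container follows; constructing it with `None` entries is done only to keep the example self-contained, since real maps come from a provider:

```python
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap

# An empty map just to demonstrate the lookup API.
loaders = DataLoaderMap(train=None, val=None, test=None)
assert loaders["train"] is loaders.train  # __getitem__ mirrors attribute access
# loaders["unknown"] would raise ValueError: not a valid split name.
```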
def dataloader_zoo(
datasets: DatasetMap,
batch_size: int = 1,
num_workers: int = 0,
dataset_len: int = 1000,
dataset_len_val: int = 1,
images_per_seq_options: Sequence[int] = (2,),
sample_consecutive_frames: bool = False,
consecutive_frames_max_gap: int = 0,
consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dataloaders:
class DataLoaderMapProviderBase(ReplaceableBase):
"""
Provider of a collection of data loaders for a given collection of datasets.
"""
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
"""
Returns a collection of data loaders for a given collection of datasets.
"""
raise NotImplementedError()
@registry.register
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
"""
Returns a set of dataloaders for a given set of datasets.
The default implementation of DataLoaderMapProviderBase.
Args:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
batch_size: The size of the batch of the dataloader.
Members:
batch_size: The size of the batch of the data loader.
num_workers: Number of data-loading threads.
dataset_len: The number of batches in a training epoch.
dataset_len_val: The number of batches in a validation epoch.
@@ -69,48 +80,60 @@ def dataloader_zoo(
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
Returns:
dataloaders: A dictionary containing the
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}
def train_or_val_loader(
dataset: Optional[ImplicitronDatasetBase], num_batches: int
) -> Optional[torch.utils.data.DataLoader]:
if dataset is None:
return None
batch_sampler = SceneBatchSampler(
dataset,
batch_size,
num_batches=len(dataset) if num_batches <= 0 else num_batches,
images_per_seq_options=images_per_seq_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
batch_size: int = 1
num_workers: int = 0
dataset_len: int = 1000
dataset_len_val: int = 1
images_per_seq_options: Sequence[int] = (2,)
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
"""
Returns a collection of data loaders for a given collection of datasets.
"""
data_loader_kwargs = {
"num_workers": self.num_workers,
"collate_fn": FrameData.collate,
}
def train_or_val_loader(
dataset: Optional[ImplicitronDatasetBase], num_batches: int
) -> Optional[torch.utils.data.DataLoader]:
if dataset is None:
return None
batch_sampler = SceneBatchSampler(
dataset,
self.batch_size,
num_batches=len(dataset) if num_batches <= 0 else num_batches,
images_per_seq_options=self.images_per_seq_options,
sample_consecutive_frames=self.sample_consecutive_frames,
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
)
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
**data_loader_kwargs,
)
train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)
test_dataset = datasets.test
if test_dataset is not None:
test_data_loader = torch.utils.data.DataLoader(
test_dataset,
batch_sampler=test_dataset.get_eval_batches(),
**data_loader_kwargs,
)
else:
test_data_loader = None
return DataLoaderMap(
train=train_data_loader, val=val_data_loader, test=test_data_loader
)
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
**dataloader_kwargs,
)
train_dataloader = train_or_val_loader(datasets.train, dataset_len)
val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)
test_dataset = datasets.test
if test_dataset is not None:
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_sampler=test_dataset.get_eval_batches(),
**dataloader_kwargs,
)
else:
test_dataloader = None
return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)
enable_get_default_args(dataloader_zoo)
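To illustrate the pluggability the summary refers to, here is a hedged sketch of a custom provider. The class `SimpleDataLoaderMapProvider`, its `batch_size` field, and its plain sequential `DataLoader` strategy are invented for this example and are not part of the commit:

```python
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DataLoaderMap,
    DataLoaderMapProviderBase,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.tools.config import registry


@registry.register
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
    """Hypothetical provider wrapping each dataset in a plain DataLoader."""

    batch_size: int = 4

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        def wrap(dataset):
            # Keep missing splits as None, matching DataLoaderMap's Optional fields.
            if dataset is None:
                return None
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                collate_fn=FrameData.collate,
            )

        return DataLoaderMap(
            train=wrap(datasets.train),
            val=wrap(datasets.val),
            test=wrap(datasets.test),
        )
```

If registered like this, the new type would presumably be selectable via `data_loader_map_provider_class_type: SimpleDataLoaderMapProvider`, with its fields exposed under a matching `data_loader_map_provider_SimpleDataLoaderMapProvider_args` block, by analogy with the defaults dumped at the end of this diff.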
@@ -6,15 +6,10 @@
from typing import Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import (
get_default_args_field,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation
from . import json_index_dataset_map_provider # noqa
from .dataloader_zoo import dataloader_zoo, Dataloaders
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
@@ -24,7 +19,7 @@ class DataSourceBase(ReplaceableBase):
and DataLoader configuration.
"""
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
raise NotImplementedError()
@@ -36,18 +31,20 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
Members:
dataset_map_provider_class_type: identifies type for dataset_map_provider.
e.g. JsonIndexDatasetMapProvider for Co3D.
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
"""
dataset_map_provider: DatasetMapProviderBase
dataset_map_provider_class_type: str
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
data_loader_map_provider: DataLoaderMapProviderBase
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
def __post_init__(self):
run_auto_creation(self)
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
datasets = self.dataset_map_provider.get_dataset_map()
dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
return datasets, dataloaders
def get_task(self) -> Task:
......
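For orientation, a sketch of how the refactored `ImplicitronDataSource` might be exercised directly; `get_default_args` comes from the same config module, while the category value is a placeholder and in practice further fields (e.g. the dataset root) would need to be set:

```python
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

cfg = get_default_args(ImplicitronDataSource)
cfg.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
cfg.dataset_map_provider_JsonIndexDatasetMapProvider_args.category = "teddybear"

data_source = ImplicitronDataSource(**cfg)
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
if dataloaders.train is not None:
    batch = next(iter(dataloaders.train))  # a collated FrameData batch
```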
dataset_map_provider_class_type: ???
dataloader_args:
batch_size: 1
num_workers: 0
dataset_len: 1000
dataset_len_val: 1
images_per_seq_options:
- 2
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_JsonIndexDatasetMapProvider_args:
category: ???
task_str: singlesequence
@@ -31,3 +22,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
image_height: 800
remove_empty_masks: true
path_manager: null
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
dataset_len: 1000
dataset_len_val: 1
images_per_seq_options:
- 2
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1