Commit 0f12c516 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

data_loader_map_provider

Summary: Replace the monolithic `dataloader_zoo` function with a pluggable `DataLoaderMapProvider` component.

Reviewed By: shapovalov

Differential Revision: D36475441

fbshipit-source-id: d16abb190d876940434329928f2e3f2794a25416
parent 79c61a2d
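
In outline, the change swaps a function call for a method on a replaceable, registered component. A rough before/after sketch (names taken from the diff below; this block is illustrative only, not code from the commit):

```python
# Before this commit: one function, every option a keyword argument.
#   dataloaders = dataloader_zoo(datasets, batch_size=10, num_workers=4)
#
# After this commit: a pluggable provider object, selected and configured
# through the standard Implicitron config machinery.
#   provider = SequenceDataLoaderMapProvider(batch_size=10, num_workers=4)
#   dataloaders = provider.get_data_loader_map(datasets)
```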
@@ -237,7 +237,7 @@ generic_model_args: GenericModel
 solver_args: init_optimizer
 data_source_args: ImplicitronDataSource
 └-- dataset_map_provider_*_args
-└-- dataloader_args
+└-- data_loader_map_provider_*_args
 ```
 
 Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
...
@@ -7,7 +7,7 @@ visualize_interval: 0
 visdom_port: 8097
 data_source_args:
   dataset_provider_class_type: JsonIndexDatasetMapProvider
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
...
@@ -2,7 +2,7 @@ defaults:
 - repro_base.yaml
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
...
@@ -2,7 +2,7 @@ defaults:
 - repro_base
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     dataset_len: 1000
     dataset_len_val: 1
...
@@ -2,7 +2,7 @@ defaults:
 - repro_singleseq_base
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
...
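The four config updates above all follow the same pluggable-component naming convention: the fixed `dataloader_args` block becomes `data_loader_map_provider_<ClassType>_args`. A minimal sketch of the renamed key using omegaconf directly (the YAML fragment is illustrative, not a complete experiment config):

```python
from omegaconf import OmegaConf

# Hypothetical fragment showing only the renamed block; a real experiment
# config carries many more keys.
cfg = OmegaConf.create(
    """
    data_source_args:
      data_loader_map_provider_SequenceDataLoaderMapProvider_args:
        batch_size: 10
        dataset_len: 1000
        dataset_len_val: 1
    """
)
args = cfg.data_source_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
assert args.batch_size == 10
```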
@@ -64,8 +64,8 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
+from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
 from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
...
@@ -553,7 +553,7 @@ def _eval_and_dump(
     cfg,
     task: Task,
     datasets: DatasetMap,
-    dataloaders: Dataloaders,
+    dataloaders: DataLoaderMap,
     model,
     stats,
     device,
...
@@ -566,7 +566,7 @@ def _eval_and_dump(
 
     dataloader = dataloaders.test
     if dataloader is None:
-        raise ValueError('Dataloaders have to contain the "test" entry for eval!')
+        raise ValueError('DataLoaderMap has to contain the "test" entry for eval!')
 
     if task == Task.SINGLE_SEQUENCE:
         if datasets.train is None:
...
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Optional, Sequence
 
 import torch
-from pytorch3d.implicitron.tools.config import enable_get_default_args
+from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 
 from .dataset_base import FrameData, ImplicitronDatasetBase
 from .dataset_map_provider import DatasetMap
...
@@ -16,40 +16,51 @@ from .scene_batch_sampler import SceneBatchSampler
 @dataclass
-class Dataloaders:
+class DataLoaderMap:
     """
-    A provider of dataloaders for implicitron.
+    A collection of data loaders for Implicitron.
 
     Members:
-        train: a dataloader for training
-        val: a dataloader for validating during training
-        test: a dataloader for final evaluation
+        train: a data loader for training
+        val: a data loader for validating during training
+        test: a data loader for final evaluation
     """
 
     train: Optional[torch.utils.data.DataLoader[FrameData]]
     val: Optional[torch.utils.data.DataLoader[FrameData]]
     test: Optional[torch.utils.data.DataLoader[FrameData]]
 
+    def __getitem__(
+        self, split: str
+    ) -> Optional[torch.utils.data.DataLoader[FrameData]]:
+        """
+        Get one of the data loaders by key (name of data split)
+        """
+        if split not in ["train", "val", "test"]:
+            raise ValueError(f"{split} was not a valid split name (train/val/test)")
+        return getattr(self, split)
+
 
-def dataloader_zoo(
-    datasets: DatasetMap,
-    batch_size: int = 1,
-    num_workers: int = 0,
-    dataset_len: int = 1000,
-    dataset_len_val: int = 1,
-    images_per_seq_options: Sequence[int] = (2,),
-    sample_consecutive_frames: bool = False,
-    consecutive_frames_max_gap: int = 0,
-    consecutive_frames_max_gap_seconds: float = 0.1,
-) -> Dataloaders:
+class DataLoaderMapProviderBase(ReplaceableBase):
+    """
+    Provider of a collection of data loaders for a given collection of datasets.
+    """
+
+    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
+        """
+        Returns a collection of data loaders for a given collection of datasets.
+        """
+        raise NotImplementedError()
+
+
+@registry.register
+class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
     """
-    Returns a set of dataloaders for a given set of datasets.
+    The default implementation of DataLoaderMapProviderBase.
 
-    Args:
-        datasets: A dictionary containing the
-            `"dataset_subset_name": torch_dataset_object` key, value pairs.
-        batch_size: The size of the batch of the dataloader.
+    Members:
+        batch_size: The size of the batch of the data loader.
         num_workers: Number data-loading threads.
         dataset_len: The number of batches in a training epoch.
         dataset_len_val: The number of batches in a validation epoch.
...
@@ -69,48 +80,60 @@ def dataloader_zoo(
         maximum difference in frame_timestamp of neighbouring frames when forming
         connected segments; if both this and consecutive_frames_max_gap are 0s,
         the whole sequence is considered a segment regardless of frame timestamps.
-
-    Returns:
-        dataloaders: A dictionary containing the
-            `"dataset_subset_name": torch_dataloader_object` key, value pairs.
     """
 
-    dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}
+    batch_size: int = 1
+    num_workers: int = 0
+    dataset_len: int = 1000
+    dataset_len_val: int = 1
+    images_per_seq_options: Sequence[int] = (2,)
+    sample_consecutive_frames: bool = False
+    consecutive_frames_max_gap: int = 0
+    consecutive_frames_max_gap_seconds: float = 0.1
 
-    def train_or_val_loader(
-        dataset: Optional[ImplicitronDatasetBase], num_batches: int
-    ) -> Optional[torch.utils.data.DataLoader]:
-        if dataset is None:
-            return None
-        batch_sampler = SceneBatchSampler(
-            dataset,
-            batch_size,
-            num_batches=len(dataset) if num_batches <= 0 else num_batches,
-            images_per_seq_options=images_per_seq_options,
-            sample_consecutive_frames=sample_consecutive_frames,
-            consecutive_frames_max_gap=consecutive_frames_max_gap,
-            consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
-        )
-        return torch.utils.data.DataLoader(
-            dataset,
-            batch_sampler=batch_sampler,
-            **dataloader_kwargs,
-        )
+    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
+        """
+        Returns a collection of data loaders for a given collection of datasets.
+        """
+        data_loader_kwargs = {
+            "num_workers": self.num_workers,
+            "collate_fn": FrameData.collate,
+        }
 
-    train_dataloader = train_or_val_loader(datasets.train, dataset_len)
-    val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)
+        def train_or_val_loader(
+            dataset: Optional[ImplicitronDatasetBase], num_batches: int
+        ) -> Optional[torch.utils.data.DataLoader]:
+            if dataset is None:
+                return None
+            batch_sampler = SceneBatchSampler(
+                dataset,
+                self.batch_size,
+                num_batches=len(dataset) if num_batches <= 0 else num_batches,
+                images_per_seq_options=self.images_per_seq_options,
+                sample_consecutive_frames=self.sample_consecutive_frames,
+                consecutive_frames_max_gap=self.consecutive_frames_max_gap,
+                consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
+            )
+            return torch.utils.data.DataLoader(
+                dataset,
+                batch_sampler=batch_sampler,
+                **data_loader_kwargs,
+            )
 
-    test_dataset = datasets.test
-    if test_dataset is not None:
-        test_dataloader = torch.utils.data.DataLoader(
-            test_dataset,
-            batch_sampler=test_dataset.get_eval_batches(),
-            **dataloader_kwargs,
-        )
-    else:
-        test_dataloader = None
+        train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
+        val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)
 
-    return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)
-
-
-enable_get_default_args(dataloader_zoo)
+        test_dataset = datasets.test
+        if test_dataset is not None:
+            test_data_loader = torch.utils.data.DataLoader(
+                test_dataset,
+                batch_sampler=test_dataset.get_eval_batches(),
+                **data_loader_kwargs,
+            )
+        else:
+            test_data_loader = None
+
+        return DataLoaderMap(
+            train=train_data_loader, val=val_data_loader, test=test_data_loader
+        )
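
Since `DataLoaderMapProviderBase` is a `ReplaceableBase`, the point of the refactor is that new providers can be registered without touching this file. A minimal sketch of that extension point; `SimpleDataLoaderMapProvider` and its fields are hypothetical, not part of the commit:

```python
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DataLoaderMap,
    DataLoaderMapProviderBase,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.tools.config import registry


@registry.register
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
    """Hypothetical provider: plain sequential loaders, no scene batching."""

    batch_size: int = 4
    num_workers: int = 0

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        def loader(dataset):
            # Splits may be absent; mirror the None handling used above.
            if dataset is None:
                return None
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                collate_fn=FrameData.collate,
            )

        return DataLoaderMap(
            train=loader(datasets.train),
            val=loader(datasets.val),
            test=loader(datasets.test),
        )
```

Once registered, such a class becomes selectable from config as `data_loader_map_provider_class_type: SimpleDataLoaderMapProvider`, with its fields exposed under `data_loader_map_provider_SimpleDataLoaderMapProvider_args`.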
...
@@ -6,15 +6,10 @@
 
 from typing import Tuple
 
-from omegaconf import DictConfig
-from pytorch3d.implicitron.tools.config import (
-    get_default_args_field,
-    ReplaceableBase,
-    run_auto_creation,
-)
+from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation
 
 from . import json_index_dataset_map_provider  # noqa
-from .dataloader_zoo import dataloader_zoo, Dataloaders
+from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
...
@@ -24,7 +19,7 @@ class DataSourceBase(ReplaceableBase):
     and DataLoader configuration.
     """
 
-    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
+    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         raise NotImplementedError()
...
@@ -36,18 +31,20 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
     Members:
         dataset_map_provider_class_type: identifies type for dataset_map_provider.
             e.g. JsonIndexDatasetMapProvider for Co3D.
+        data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
     """
 
     dataset_map_provider: DatasetMapProviderBase
     dataset_map_provider_class_type: str
-    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
+    data_loader_map_provider: DataLoaderMapProviderBase
+    data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
 
     def __post_init__(self):
         run_auto_creation(self)
 
-    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
+    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         datasets = self.dataset_map_provider.get_dataset_map()
-        dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
+        dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
         return datasets, dataloaders
 
     def get_task(self) -> Task:
...
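End to end, the data source now auto-creates whichever provider the config names. A hypothetical sketch; `get_default_args` belongs to the same config toolkit as `run_auto_creation`, but its use here and the example values (`"teddybear"`, batch size) are assumptions, not shown in this diff:

```python
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
args.dataset_map_provider_JsonIndexDatasetMapProvider_args.category = "teddybear"
# The default provider; override batch_size through its pluggable args block.
args.data_loader_map_provider_SequenceDataLoaderMapProvider_args.batch_size = 10

data_source = ImplicitronDataSource(**args)
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
```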
 dataset_map_provider_class_type: ???
-dataloader_args:
-  batch_size: 1
-  num_workers: 0
-  dataset_len: 1000
-  dataset_len_val: 1
-  images_per_seq_options:
-  - 2
-  sample_consecutive_frames: false
-  consecutive_frames_max_gap: 0
-  consecutive_frames_max_gap_seconds: 0.1
+data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
 dataset_map_provider_JsonIndexDatasetMapProvider_args:
   category: ???
   task_str: singlesequence
...
@@ -31,3 +22,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
   image_height: 800
   remove_empty_masks: true
   path_manager: null
+data_loader_map_provider_SequenceDataLoaderMapProvider_args:
+  batch_size: 1
+  num_workers: 0
+  dataset_len: 1000
+  dataset_len_val: 1
+  images_per_seq_options:
+  - 2
+  sample_consecutive_frames: false
+  consecutive_frames_max_gap: 0
+  consecutive_frames_max_gap_seconds: 0.1
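
One behavioral addition worth noting alongside the renames: `DataLoaderMap` gained a `__getitem__`, so loaders can be fetched by split name as well as by attribute. A tiny illustrative check (the empty map is just for demonstration):

```python
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap

loaders = DataLoaderMap(train=None, val=None, test=None)
assert loaders["train"] is loaders.train  # key and attribute access agree
try:
    loaders["bogus"]
except ValueError as err:
    print(err)  # bogus was not a valid split name (train/val/test)
```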