Commit 3239594f authored by Roman Shapovalov's avatar Roman Shapovalov Committed by Facebook GitHub Bot
Browse files

Fix: Correct concatenation of datasets in train conditioning

Summary: ChainDataset is iterable, and it toes not go along with a custom batch sampler.

Reviewed By: bottler

Differential Revision: D42742315

fbshipit-source-id: 40a715c8d24abe72cb2777634247d7467f628564
parent 11959e0b
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from torch.utils.data import ( from torch.utils.data import (
BatchSampler, BatchSampler,
ChainDataset, ConcatDataset,
DataLoader, DataLoader,
RandomSampler, RandomSampler,
Sampler, Sampler,
...@@ -482,7 +482,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase): ...@@ -482,7 +482,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
num_batches=num_batches, num_batches=num_batches,
) )
return DataLoader( return DataLoader(
ChainDataset([dataset, train_dataset]), ConcatDataset([dataset, train_dataset]),
batch_sampler=sampler, batch_sampler=sampler,
**data_loader_kwargs, **data_loader_kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment