Unverified Commit 286dc19a authored by Jonathan Chang's avatar Jonathan Chang Committed by GitHub
Browse files

Fix IterableDataset with __len__ in Trainer (#8095)

parent d93acd6f
......@@ -384,7 +384,9 @@ class Trainer:
dataset.set_format(type=dataset.format["type"], columns=columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
self.train_dataset, collections.abc.Sized
):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment