Unverified Commit 9e346f74 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix number of examples for iterable datasets in multiprocessing (#18856)

* Fix number of examples for iterable datasets in multiprocessing

* Add stronger check
parent 0ab465a5
...@@ -3040,13 +3040,15 @@ class Trainer: ...@@ -3040,13 +3040,15 @@ class Trainer:
num_samples = len(eval_dataset) num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right # The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute. # methods. Therefore we need to make sure it also has the attribute.
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"): elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
num_samples = eval_dataset.num_examples num_samples = eval_dataset.num_examples
else: else:
if has_length(dataloader): if has_length(dataloader):
num_samples = self.num_examples(dataloader) num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples num_samples = observed_num_examples
if num_samples == 0 and observed_num_examples > 0:
num_samples = observed_num_examples
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# samplers has been rounded to a multiple of batch_size, so we truncate. # samplers has been rounded to a multiple of batch_size, so we truncate.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment