Unverified Commit a21ee1f9 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Implement len in IterableDatasetShard (#13780)

parent 83d3dc0f
......@@ -772,6 +772,13 @@ class IterableDatasetShard(IterableDataset):
for i in process_slice:
yield current_batch[i]
def __len__(self):
# Will raise an error if the underlying dataset is not sized.
if self.drop_last:
return len(self.dataset) // self.num_processes
else:
return math.ceil(len(self.dataset) / self.num_processes)
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment