Unverified Commit 03056730 authored by Maxwell Forbes's avatar Maxwell Forbes Committed by GitHub
Browse files

Fall back to `observed_batch_size` when the `dataloader` does not know the `batch_size`. (#13188)

parent ce6add8e
......@@ -2203,6 +2203,9 @@ class Trainer:
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size
# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment