"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1d1d5bec1b30fc4fa437cba5f1bcc4bf09cc1f4c"
Unverified Commit 03056730 authored by Maxwell Forbes's avatar Maxwell Forbes Committed by GitHub
Browse files

Fall back to `observed_batch_size` when the `dataloader` does not know the `batch_size`. (#13188)

parent ce6add8e
...@@ -2203,6 +2203,9 @@ class Trainer: ...@@ -2203,6 +2203,9 @@ class Trainer:
observed_batch_size = find_batch_size(inputs) observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None: if observed_batch_size is not None:
observed_num_examples += observed_batch_size observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size
# Prediction step # Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment