Unverified Commit 44127ec6 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Fix test for auto_find_batch_size on multi-GPU (#27947)

* Fix test for multi-GPU

* With CPU handle
parent b911c1f1
......@@ -1558,7 +1558,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
class MockCudaOOMCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
# simulate OOM on the first step
if state.train_batch_size == 16:
if state.train_batch_size >= 16:
raise RuntimeError("CUDA out of memory.")
args = RegressionTrainingArguments(
......@@ -1577,7 +1577,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# We can then make a new Trainer
trainer = Trainer(model, args, train_dataset=train_dataset)
# Check we are at 16 to start
self.assertEqual(trainer._train_batch_size, 16)
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
trainer.train(resume_from_checkpoint=True)
# We should be back to 8 again, picking up based upon the last ran Trainer
self.assertEqual(trainer._train_batch_size, 8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment