Unverified Commit 3f6973db authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[tests] use the correct `n_gpu` in...

[tests] use the correct `n_gpu` in `TrainerIntegrationTest::test_train_and_eval_dataloaders` for XPU (#29307)

* fix n_gpu

* fix style
parent 1ba89dc2
...@@ -1029,7 +1029,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -1029,7 +1029,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(is_any_loss_nan_or_inf(log_history_filter)) self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
def test_train_and_eval_dataloaders(self): def test_train_and_eval_dataloaders(self):
n_gpu = max(1, backend_device_count(torch_device)) if torch_device == "cuda":
n_gpu = max(1, backend_device_count(torch_device))
else:
n_gpu = 1
trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16) trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu) self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu)
trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16) trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment