Unverified Commit 83ef8bca authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix finite IterableDataset test on multiple GPUs (#14445)

parent da36c557
...@@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler) self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
def test_training_finite_iterable_dataset(self): def test_training_finite_iterable_dataset(self):
num_gpus = max(1, get_gpu_count())
if num_gpus > 2:
return
config = RegressionModelConfig() config = RegressionModelConfig()
model = RegressionPreTrainedModel(config) model = RegressionPreTrainedModel(config)
batch_size = 1 batch_size = 1
num_samples = 10 num_samples = 10
available_steps = num_samples // batch_size available_steps = num_samples // (batch_size * num_gpus)
data = FiniteIterableDataset(length=num_samples) data = FiniteIterableDataset(length=num_samples)
train_args = TrainingArguments( train_args = TrainingArguments(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment