Unverified Commit 485fd814 authored by Bastien Le Chenadec's avatar Bastien Le Chenadec Committed by GitHub
Browse files

Support multiple validation datasets when `dataloader_persistent_workers=True` (#30627)

* Support multiple validation datasets when dataloader_persistent_workers=True

* Test support of multiple validation datasets
parent 147c404f
......@@ -919,25 +919,36 @@ class Trainer:
else:
return None
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (`torch.utils.data.Dataset`, *optional*):
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
by the `model.forward()` method are automatically removed. It must implement `__len__`.
eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
# If we have persistent workers, don't do a fork bomb especially as eval datasets
# don't change during training
if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
return self.accelerator.prepare(self._eval_dataloader)
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
if (
hasattr(self, "_eval_dataloaders")
and dataloader_key in self._eval_dataloaders
and self.args.dataloader_persistent_workers
):
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
eval_dataset = (
self.eval_dataset[eval_dataset]
if isinstance(eval_dataset, str)
else eval_dataset
if eval_dataset is not None
else self.eval_dataset
)
data_collator = self.data_collator
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
......@@ -962,7 +973,10 @@ class Trainer:
# we need to store the non-prepared version
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
if self.args.dataloader_persistent_workers:
self._eval_dataloader = eval_dataloader
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = eval_dataloader
else:
self._eval_dataloaders = {dataloader_key: eval_dataloader}
return self.accelerator.prepare(eval_dataloader)
......@@ -3584,12 +3598,13 @@ class Trainer:
dictionary also contains the epoch number which comes from the training state.
"""
# handle multiple eval datasets
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
override = eval_dataset is not None
eval_dataset = eval_dataset if override else self.eval_dataset
if isinstance(eval_dataset, dict):
metrics = {}
for eval_dataset_name, _eval_dataset in eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=_eval_dataset,
eval_dataset=_eval_dataset if override else eval_dataset_name,
ignore_keys=ignore_keys,
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
)
......
......@@ -1231,6 +1231,97 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
trainer.evaluate()
def test_get_eval_dataloader_without_persistent_workers(self):
    """Without persistent workers, every `get_eval_dataloader` call builds a brand-new dataloader."""
    regression_train = RegressionDataset()
    model_config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    model = GPT2LMHeadModel(model_config)
    training_args = TrainingArguments("./test", report_to="none", dataloader_persistent_workers=False)

    # --- single evaluation dataset ---
    single_eval = RegressionDataset()
    trainer = Trainer(model, training_args, train_dataset=regression_train, eval_dataset=single_eval)
    # Turn accelerator.prepare into a no-op so the raw dataloader objects can be compared directly
    trainer.accelerator.prepare = lambda dl: dl

    loader_default = trainer.get_eval_dataloader()
    loader_explicit = trainer.get_eval_dataloader(single_eval)
    # Both loaders wrap the same dataset ...
    self.assertEqual(loader_default.dataset, single_eval)
    self.assertEqual(loader_explicit.dataset, single_eval)
    # ... but each call produced a distinct dataloader instance
    self.assertNotEqual(loader_default, loader_explicit)

    # --- multiple evaluation datasets keyed by name ---
    eval_sets = {"first": RegressionDataset(), "second": RegressionDataset()}
    trainer = Trainer(
        model,
        training_args,
        train_dataset=regression_train,
        eval_dataset=eval_sets,
    )
    # Same no-op mock as above: keep the dataloaders comparable across calls
    trainer.accelerator.prepare = lambda dl: dl

    for name, dataset in eval_sets.items():
        loader_a = trainer.get_eval_dataloader(name)
        loader_b = trainer.get_eval_dataloader(name)
        # The named loader serves the matching dataset on every call ...
        self.assertEqual(dataset, loader_a.dataset)
        self.assertEqual(loader_a.dataset, loader_b.dataset)
        # ... yet a fresh dataloader is constructed each time
        self.assertNotEqual(loader_a, loader_b)
def test_get_eval_dataloader_with_persistent_workers(self):
    """With persistent workers, repeated `get_eval_dataloader` calls reuse the cached dataloader."""
    regression_train = RegressionDataset()
    model_config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    model = GPT2LMHeadModel(model_config)
    training_args = TrainingArguments(
        "./test",
        report_to="none",
        dataloader_persistent_workers=True,
        dataloader_num_workers=2,
    )

    # --- single evaluation dataset ---
    single_eval = RegressionDataset()
    trainer = Trainer(model, training_args, train_dataset=regression_train, eval_dataset=single_eval)
    # Turn accelerator.prepare into a no-op so the raw dataloader objects can be compared directly
    trainer.accelerator.prepare = lambda dl: dl

    loader_default = trainer.get_eval_dataloader()
    loader_explicit = trainer.get_eval_dataloader(single_eval)
    # Both calls serve the same dataset ...
    self.assertEqual(loader_default.dataset, single_eval)
    self.assertEqual(loader_explicit.dataset, single_eval)
    # ... and, with persistent workers on, return the very same dataloader
    self.assertEqual(loader_default, loader_explicit)

    # --- multiple evaluation datasets keyed by name ---
    eval_sets = {"first": RegressionDataset(), "second": RegressionDataset()}
    trainer = Trainer(
        model,
        training_args,
        train_dataset=regression_train,
        eval_dataset=eval_sets,
    )
    # Same no-op mock as above: keep the dataloaders comparable across calls
    trainer.accelerator.prepare = lambda dl: dl

    for name, dataset in eval_sets.items():
        loader_a = trainer.get_eval_dataloader(name)
        loader_b = trainer.get_eval_dataloader(name)
        # Each named loader wraps its matching dataset ...
        self.assertEqual(dataset, loader_a.dataset)
        self.assertEqual(loader_a.dataset, loader_b.dataset)
        # ... and the cached dataloader is handed back unchanged on repeat calls
        self.assertEqual(loader_a, loader_b)
@require_lomo
@require_torch_gpu
def test_lomo(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment