"sgl-router/vscode:/vscode.git/clone" did not exist on "4ac8e09df0a908c05db4b952ce1b19d705dc92b8"
Unverified Commit 769a9542 authored by peter-sk, committed by GitHub

move code to Trainer.evaluate to enable use of that function with multiple datasets (#27844)



* move code to Trainer.evaluate to enable use of that function with multiple datasets

* test

* update doc string

* and a tip

* forgot the type

---------
Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
parent cd9f9d63
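For orientation before the diff: a minimal usage sketch of what this change enables (not part of the commit). The tiny checkpoint `hf-internal-testing/tiny-random-gpt2` and the toy list-of-dicts datasets are illustrative stand-ins chosen to keep the example light; any causal LM and any datasets implementing `__len__` work the same way.

```python
import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Tiny random-weight checkpoint, used here only so the sketch runs quickly.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Two toy eval sets; a list of feature dicts satisfies the __len__ requirement.
ids = torch.tensor([1, 2, 3, 4])
sample = {"input_ids": ids, "labels": ids}
data1, data2 = [sample] * 4, [sample] * 4

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./out", per_device_eval_batch_size=2, use_cpu=True),
    eval_dataset={"data1": data1, "data2": data2},  # dict of named datasets
)

# One call now evaluates both datasets, prefixing each metric with its name.
metrics = trainer.evaluate()
print(metrics["eval_data1_loss"], metrics["eval_data2_loss"])
```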
src/transformers/trainer.py
@@ -2261,16 +2261,6 @@ class Trainer:
         metrics = None
         if self.control.should_evaluate:
-            if isinstance(self.eval_dataset, dict):
-                metrics = {}
-                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
-                    dataset_metrics = self.evaluate(
-                        eval_dataset=eval_dataset,
-                        ignore_keys=ignore_keys_for_eval,
-                        metric_key_prefix=f"eval_{eval_dataset_name}",
-                    )
-                    metrics.update(dataset_metrics)
-            else:
-                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
             self._report_to_hp_search(trial, self.state.global_step, metrics)
@@ -2997,7 +2987,7 @@ class Trainer:
     def evaluate(
         self,
-        eval_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "eval",
     ) -> Dict[str, float]:
@@ -3010,10 +3000,24 @@ class Trainer:
         You can also subclass and override this method to inject custom behavior.

         Args:
-            eval_dataset (`Dataset`, *optional*):
+            eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]], *optional*):
                 Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
-                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
-                method.
+                not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
+                evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
+                `__len__` method.
+
+                <Tip>
+
+                If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
+                separate evaluations on each dataset. This can be useful to monitor how training affects other
+                datasets or simply to get a more fine-grained evaluation.
+
+                When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
+                of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
+                `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for the loss on
+                `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
+
+                </Tip>
+
             ignore_keys (`List[str]`, *optional*):
                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                 gathering predictions.
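The tip in the docstring above pairs with `load_best_model_at_end`; a hedged configuration sketch (not from the commit) of what that looks like. The names `data1`/`data2` follow the docstring's example, and `model`/`train_dataset` are placeholders assumed to be defined as in the sketch near the top of this page:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    evaluation_strategy="epoch",
    save_strategy="epoch",                    # must match evaluation_strategy
    load_best_model_at_end=True,
    metric_for_best_model="eval_data1_loss",  # exactly one dataset's metric
    greater_is_better=False,                  # lower loss is better
)

trainer = Trainer(
    model=model,                      # placeholder: any model
    args=args,
    train_dataset=train_dataset,      # placeholder: training split
    eval_dataset={"data1": data1, "data2": data2},
)
```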
@@ -3025,6 +3029,19 @@ class Trainer:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
+        # handle multiple eval datasets
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+        if isinstance(eval_dataset, dict):
+            metrics = {}
+            for eval_dataset_name, _eval_dataset in eval_dataset.items():
+                dataset_metrics = self.evaluate(
+                    eval_dataset=_eval_dataset,
+                    ignore_keys=ignore_keys,
+                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+                )
+                metrics.update(dataset_metrics)
+            return metrics
+
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
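The dispatch added above recurses once per dict entry and threads the dataset name into `metric_key_prefix`. A self-contained mock of that key scheme (the helper `fake_single_eval` is hypothetical, standing in for the real per-dataset evaluation loop):

```python
from typing import Dict, List, Union

def fake_single_eval(metric_key_prefix: str) -> Dict[str, float]:
    # Stand-in for the real evaluation loop, which builds keys like
    # "eval_loss" and "eval_runtime" from the prefix it is given.
    return {f"{metric_key_prefix}_loss": 0.0, f"{metric_key_prefix}_runtime": 1.0}

def evaluate(eval_dataset: Union[List, Dict[str, List]], metric_key_prefix: str = "eval") -> Dict[str, float]:
    # Mirrors the added code: recurse over dict entries, appending each
    # dataset's name to the prefix, then merge the per-dataset metric dicts.
    if isinstance(eval_dataset, dict):
        metrics = {}
        for name, ds in eval_dataset.items():
            metrics.update(evaluate(ds, metric_key_prefix=f"{metric_key_prefix}_{name}"))
        return metrics
    return fake_single_eval(metric_key_prefix)

print(evaluate({"data1": [], "data2": []}))
# -> {'eval_data1_loss': 0.0, 'eval_data1_runtime': 1.0,
#     'eval_data2_loss': 0.0, 'eval_data2_runtime': 1.0}
```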
tests/trainer/test_trainer.py
@@ -103,6 +103,7 @@ if is_torch_available():
     import transformers.optimization
     from transformers import (
         AutoModelForCausalLM,
+        AutoModelForSequenceClassification,
         EarlyStoppingCallback,
         GlueDataset,
@@ -1845,6 +1846,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         result = trainer.evaluate()
         self.assertLess(result["eval_loss"], 0.2)

+    @slow
+    def test_trainer_eval_multiple(self):
+        MODEL_ID = "gpt2"
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+        dataset = LineByLineTextDataset(
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
+        )
+        for example in dataset.examples:
+            example["labels"] = example["input_ids"]
+        training_args = TrainingArguments(
+            output_dir="./examples",
+            use_cpu=True,
+            per_device_eval_batch_size=1,
+        )
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            eval_dataset={
+                "data1": dataset,
+                "data2": dataset,
+            },
+        )
+        result = trainer.evaluate()
+        self.assertIn("eval_data1_loss", result)
+        self.assertIn("eval_data2_loss", result)
+
     @slow
     def test_trainer_eval_lm(self):
         MODEL_ID = "distilroberta-base"
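One consequence of the new dispatch worth noting: a custom `metric_key_prefix` composes with the dict keys, since the code builds `f"{metric_key_prefix}_{eval_dataset_name}"`. A short sketch continuing the `trainer` and datasets from the usage example near the top of this page (assumed to be in scope):

```python
# "test" + "data1" -> keys such as "test_data1_loss".
test_metrics = trainer.evaluate(
    eval_dataset={"data1": data1, "data2": data2},
    metric_key_prefix="test",
)
assert "test_data1_loss" in test_metrics
assert "test_data2_loss" in test_metrics
```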