Unverified Commit 4a0b958d authored by Tatsuki Okada, committed by GitHub

Fix trainer seq2seq qa.py evaluate log and ft script (#19208)

* fix args option

* fix trainer eval log

* fix out of memory qa script

* do isort, black, flake

* fix tokenize target

* take it back.

* fix: comment
parent 9c6aeba3
examples/pytorch/question-answering/run_seq2seq_qa.py

@@ -327,21 +327,28 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         data_files = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
             extension = data_args.train_file.split(".")[-1]
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
             extension = data_args.validation_file.split(".")[-1]
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(
+            extension,
+            data_files=data_files,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -359,7 +366,7 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
-        use_fast=True,
+        use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
@@ -476,9 +483,10 @@ def main():
             max_length=max_seq_length,
             padding=padding,
             truncation=True,
+            return_overflowing_tokens=True,
             return_offsets_mapping=True,
-            return_overflowing_tokens=True,
         )
+        # Tokenize targets with the `text_target` keyword argument
         labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
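For readers unfamiliar with the two flags kept in the call above, here is a small standalone sketch (illustrative only, assuming a fast tokenizer such as t5-small; it is not part of this diff) of what return_overflowing_tokens and return_offsets_mapping add to the encoding: a long input is split into several features, and each feature carries character offsets plus the index of the example it came from, which is what the post-processing step relies on.

from transformers import AutoTokenizer

# Illustrative only: any fast tokenizer works; t5-small is assumed here.
tok = AutoTokenizer.from_pretrained("t5-small", use_fast=True)

long_input = "question: Who wrote it?  context: " + "some long context " * 200
enc = tok(
    [long_input],
    max_length=64,
    truncation=True,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(len(enc["input_ids"]))              # one example becomes several features
print(enc["overflow_to_sample_mapping"])  # every feature points back to example 0
print(enc["offset_mapping"][0][:5])       # (char_start, char_end) per token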
@@ -503,6 +511,7 @@ def main():
             ]
         model_inputs["labels"] = labels["input_ids"]
         return model_inputs
+
     if training_args.do_train:
@@ -627,7 +636,7 @@ def main():
         eval_examples=eval_examples if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        compute_metrics=compute_metrics,
+        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
         post_process_function=post_processing_function,
     )
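A note on the gating above: the SQuAD metric needs generated answer text, and the evaluation loop only produces generations when predict_with_generate is enabled, so the script now hands the trainer None otherwise. A minimal sketch of the same pattern, assuming Seq2SeqTrainingArguments and a placeholder metric function (not the script's actual metric):

from transformers import Seq2SeqTrainingArguments

# Placeholder metric for illustration; the real script scores SQuAD EM/F1
# on decoded generations.
def compute_metrics(eval_preds):
    return {"exact_match": 0.0}

training_args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,  # text metrics need generation at eval time
)

# Only pass the metric function when generations will actually be produced.
metrics_fn = compute_metrics if training_args.predict_with_generate else None
print(metrics_fn)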
examples/pytorch/question-answering/trainer_seq2seq_qa.py
@@ -15,12 +15,14 @@
 """
 A subclass of `Trainer` specific to Question-Answering tasks
 """
+import math
+import time
 from typing import Dict, List, Optional

 from torch.utils.data import Dataset

 from transformers import Seq2SeqTrainer, is_torch_tpu_available
-from transformers.trainer_utils import PredictionOutput
+from transformers.trainer_utils import PredictionOutput, speed_metrics

 if is_torch_tpu_available(check_device=False):
@@ -59,6 +61,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
+        start_time = time.time()
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
         try:
             output = eval_loop(
@@ -71,6 +74,15 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
             )
         finally:
             self.compute_metrics = compute_metrics
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )

         if self.post_process_function is not None and self.compute_metrics is not None:
             eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
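For context on the block added above: speed_metrics from transformers.trainer_utils turns a wall-clock duration into prefixed throughput entries. A small self-contained sketch of the keys it contributes (the sample count and batch size here are illustrative, not taken from the script):

import math
import time

from transformers.trainer_utils import speed_metrics

start_time = time.time()
time.sleep(0.1)  # stand-in for the evaluation loop

num_samples = 10570      # illustrative sample count
total_batch_size = 8     # eval_batch_size * world_size in the trainer

# Yields keys prefixed with the metric prefix, e.g. "eval_runtime",
# "eval_samples_per_second" and "eval_steps_per_second".
print(
    speed_metrics(
        "eval",
        start_time,
        num_samples=num_samples,
        num_steps=math.ceil(num_samples / total_batch_size),
    )
)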
@@ -81,15 +93,15 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 if not key.startswith(f"{metric_key_prefix}_"):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

+            output.metrics.update(metrics)
             self.log(metrics)
+        else:
+            metrics = {}

         if self.args.tpu_metrics_debug or self.args.debug:
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())

-        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

         return metrics
     def predict(
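Regarding the else: metrics = {} and output.metrics changes in the last hunk: before this fix, metrics was only assigned inside the if branch, so running evaluation without --predict_with_generate (where compute_metrics is None) hit an unbound local on the final callback and return. A minimal standalone sketch of that control-flow issue and the guard, not the Trainer code itself:

# Stripped-down illustration of the evaluate() control flow; names are
# placeholders, not the actual Trainer internals.
def evaluate_like(compute_metrics=None):
    if compute_metrics is not None:
        metrics = compute_metrics()
    else:
        # Without this branch, the return below raises UnboundLocalError
        # whenever metric computation is skipped.
        metrics = {}
    return metrics


print(evaluate_like())                               # {} even with no metric function
print(evaluate_like(lambda: {"exact_match": 80.0}))  # {'exact_match': 80.0}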