"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f9cc333805c47665c6afee8b5867931e54abe0c6"
Unverified Commit f51188cb authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[examples/run_s2s] remove task_specific_params and update rouge computation (#10133)

* fix rouge metrics and task specific params

* fix typo

* round metrics

* typo

* remove task_specific_params
parent 31245775
......@@ -25,10 +25,12 @@ import sys
from dataclasses import dataclass, field
from typing import Optional
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import load_dataset, load_metric
import transformers
from filelock import FileLock
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
......@@ -44,6 +46,10 @@ from transformers import (
from transformers.trainer_utils import get_last_checkpoint, is_main_process
with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)
logger = logging.getLogger(__name__)
......@@ -110,10 +116,22 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
......@@ -298,6 +316,9 @@ def main():
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
......@@ -335,15 +356,7 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# Get the default prefix if None is passed.
if data_args.source_prefix is None:
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
prefix = task_specific_params.get("prefix", "")
else:
prefix = ""
else:
prefix = data_args.source_prefix
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
......@@ -487,6 +500,19 @@ def main():
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
metric = load_metric(metric_name)
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
if metric_name == "rouge":
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
else: # sacrebleu
labels = [[label] for label in labels]
return preds, labels
def compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
......@@ -498,22 +524,19 @@ def main():
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds]
decoded_labels = [label.strip() for label in decoded_labels]
if metric_name == "sacrebleu":
decoded_labels = [[label] for label in decoded_labels]
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
# Extract a few results from ROUGE
if metric_name == "rouge":
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
else:
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
result = {"bleu": result["score"]}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result
# Initialize our Trainer
......@@ -555,6 +578,7 @@ def main():
logger.info("*** Evaluate ***")
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
results = {k: round(v, 4) for k, v in results.items()}
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero():
......@@ -574,6 +598,7 @@ def main():
num_beams=data_args.num_beams,
)
test_metrics = test_results.metrics
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment