"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4ce74edf5118ed6490f5ed95b7c89b0190e174ed"
Unverified Commit 4f24058c authored by karthikrangasai, committed by GitHub

Update Seq2Seq QA example script to use SQuAD metric. (#14335)

* Update post-processing to use the SQuAD metric.

* Update test assertions to match the SQuAD metrics.

* Fix function naming error.
parent be4a6c64
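For context, the diff below switches the example from ROUGE to the SQuAD metric from the `datasets` library, which expects predictions and references as lists of dictionaries; `post_processing_function` in the diff builds exactly this format. A minimal, self-contained sketch of that input format (the example id and answer text are made up for illustration, not taken from this commit):

```python
from datasets import load_metric

# Illustration of the input format the SQuAD metric expects; the real script
# builds these dictionaries in post_processing_function from decoded model output.
metric = load_metric("squad")

predictions = [{"id": "example-0", "prediction_text": "Denver Broncos"}]
references = [
    {"id": "example-0", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}
]

print(metric.compute(predictions=predictions, references=references))
# -> {'exact_match': 100.0, 'f1': 100.0}
```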
@@ -25,22 +25,20 @@ from dataclasses import dataclass, field
 from typing import List, Optional, Tuple
 
 import datasets
-import nltk
-import numpy as np
 from datasets import load_dataset, load_metric
 
 import transformers
+from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
-    Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
+from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -411,7 +409,7 @@ def main():
     )
 
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
 
-    def preprocess_sqaud_batch(
+    def preprocess_squad_batch(
         examples,
         question_column: str,
         context_column: str,
@@ -422,14 +420,14 @@ def main():
         answers = examples[answer_column]
 
         def generate_input(_question, _context):
-            return " ".join(["question:", _question, "context:", _context])
+            return " ".join(["question:", _question.lstrip(), "context:", _context.lstrip()])
 
         inputs = [generate_input(question, context) for question, context in zip(questions, contexts)]
         targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
         return inputs, targets
 
     def preprocess_function(examples):
-        inputs, targets = preprocess_sqaud_batch(examples, question_column, context_column, answer_column)
+        inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)
 
         model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True)
         # Setup the tokenizer for targets
@@ -446,6 +444,45 @@ def main():
         model_inputs["labels"] = labels["input_ids"]
         return model_inputs
 
+    # Validation preprocessing
+    def preprocess_validation_function(examples):
+        inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)
+
+        model_inputs = tokenizer(
+            inputs,
+            max_length=max_seq_length,
+            padding=padding,
+            truncation=True,
+            return_overflowing_tokens=True,
+            return_offsets_mapping=True,
+        )
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)
+
+        # Since one example might give us several features if it has a long context, we need a map from a feature to
+        # its corresponding example. This key gives us just that.
+        sample_mapping = model_inputs.pop("overflow_to_sample_mapping")
+
+        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
+        # corresponding example_id and we will store the offset mappings.
+        model_inputs["example_id"] = []
+
+        for i in range(len(model_inputs["input_ids"])):
+            # One example can give several spans, this is the index of the example containing this span of text.
+            sample_index = sample_mapping[i]
+            model_inputs["example_id"].append(examples["id"][sample_index])
+
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+        # padding in the loss.
+        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
+            labels["input_ids"] = [
+                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+            ]
+
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
+
     if training_args.do_train:
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
@@ -477,7 +514,7 @@ def main():
         # Validation Feature Creation
         with training_args.main_process_first(desc="validation dataset map pre-processing"):
             eval_dataset = eval_examples.map(
-                preprocess_function,
+                preprocess_validation_function,
                 batched=True,
                 num_proc=data_args.preprocessing_num_workers,
                 remove_columns=column_names,
@@ -498,7 +535,7 @@ def main():
         # Predict Feature Creation
         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
             predict_dataset = predict_examples.map(
-                preprocess_function,
+                preprocess_validation_function,
                 batched=True,
                 num_proc=data_args.preprocessing_num_workers,
                 remove_columns=column_names,
@@ -518,50 +555,53 @@ def main():
         pad_to_multiple_of=8 if training_args.fp16 else None,
     )
 
-    # Post-processing:
-    def postprocess_text(preds, labels):
-        preds = [" ".join(pred) for pred in preds]
-        preds = [pred.strip() for pred in preds]
-        labels = [label.strip() for label in labels]
-
-        # rougeLSum expects newline after each sentence
-        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-        return preds, labels
-
-    metric = load_metric("rouge")
-
-    def compute_metrics(eval_preds: EvalPrediction):
-        preds, labels = eval_preds
-        if isinstance(preds, tuple):
-            preds = preds[0]
-        decoded_preds = [tokenizer.batch_decode(pred, skip_special_tokens=True) for pred in preds]
-        if data_args.ignore_pad_token_for_loss:
-            # Replace -100 in the labels as we can't decode them.
-            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
-        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-        # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-
-        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
-        # Extract a few results from ROUGE
-        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
-        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-        result["gen_len"] = np.mean(prediction_lens)
-        result = {k: round(v, 4) for k, v in result.items()}
-        return result
+    metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+
+    def compute_metrics(p: EvalPrediction):
+        return metric.compute(predictions=p.predictions, references=p.label_ids)
+
+    # Post-processing:
+    def post_processing_function(
+        examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
+    ):
+        # Decode the predicted tokens.
+        preds = outputs.predictions
+        if isinstance(preds, tuple):
+            preds = preds[0]
+        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+
+        # Build a map example to its corresponding features.
+        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+        feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
+        predictions = {}
+        # Let's loop over all the examples!
+        for example_index, example in enumerate(examples):
+            # This is the index of the feature associated to the current example.
+            feature_index = feature_per_example[example_index]
+            predictions[example["id"]] = decoded_preds[feature_index]
+
+        # Format the result to the format the metric expects.
+        if data_args.version_2_with_negative:
+            formatted_predictions = [
+                {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
+            ]
+        else:
+            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
+
+        references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples]
+        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 
     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer = QuestionAnsweringSeq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_examples=eval_examples if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        post_process_function=post_processing_function,
     )
 
     # Training
......
trainer_seq2seq_qa.py (new file):

# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A subclass of `Trainer` specific to Question-Answering tasks
"""

from typing import Dict, List, Optional

from torch.utils.data import Dataset

from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput


if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met


class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function

    # def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        eval_examples=None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        self._max_length = max_length if max_length is not None else self.args.generation_max_length
        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams

        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        try:
            output = eval_loop(
                eval_dataloader,
                description="Evaluation",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
            )
        finally:
            self.compute_metrics = compute_metrics

        if self.post_process_function is not None and self.compute_metrics is not None:
            eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
            metrics = self.compute_metrics(eval_preds)

            # Prefix all keys with metric_key_prefix + '_'
            for key in list(metrics.keys()):
                if not key.startswith(f"{metric_key_prefix}_"):
                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

            self.log(metrics)
        else:
            metrics = {}

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics

    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
        predict_dataloader = self.get_test_dataloader(predict_dataset)

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        try:
            output = eval_loop(
                predict_dataloader,
                description="Prediction",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
            )
        finally:
            self.compute_metrics = compute_metrics

        if self.post_process_function is None or self.compute_metrics is None:
            return output

        predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
        metrics = self.compute_metrics(predictions)

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
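A minimal usage sketch of the subclass above. It assumes a `trainer` instance built as in the `run_seq2seq_qa.py` diff earlier, i.e. a `QuestionAnsweringSeq2SeqTrainer` with `eval_dataset`, `eval_examples`, `compute_metrics`, and `post_process_function` all wired in; the generation arguments mirror the `evaluate` signature defined in this file:

```python
# Illustrative only: `trainer` is assumed to exist, constructed as in run_seq2seq_qa.py above.
metrics = trainer.evaluate(max_length=30, num_beams=1)

# Keys come back prefixed with metric_key_prefix ("eval" by default), e.g. metrics["eval_f1"],
# plus metrics["eval_exact"] with the squad_v2 metric or metrics["eval_exact_match"] with squad v1.
```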
@@ -274,10 +274,8 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             run_squad_seq2seq.main()
             result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_rouge1"], 10)
-            self.assertGreaterEqual(result["eval_rouge2"], 10)
-            self.assertGreaterEqual(result["eval_rougeL"], 10)
-            self.assertGreaterEqual(result["eval_rougeLsum"], 10)
+            self.assertGreaterEqual(result["eval_f1"], 30)
+            self.assertGreaterEqual(result["eval_exact"], 30)
 
     def test_run_swag(self):
         stream_handler = logging.StreamHandler(sys.stdout)
......