Unverified commit ca136186 authored by Julien Plu, committed by GitHub

Question Answering for TF trainer (#4320)

* Add QA trainer example for TF

* Make data_dir optional

* Fix parameter logic

* Fix feature convert

* Update the READMEs to add the question-answering task

* Apply style

* Change 'sequence-classification' to 'text-classification' and prefix all the metric names with 'eval'

* Apply style

* Apply style
parent 1e51bb71
@@ -20,6 +20,7 @@ This is still a work-in-progress – in particular documentation is still sparse
| [`text-classification`](./text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/trainer/01_text_classification.ipynb) | [![Deploy to Azure](https://aka.ms/deploytoazurebutton)](https://portal.azure.com/#create/Microsoft.Template/uri/https%3A%2F%2Fraw.githubusercontent.com%2FAzure%2Fazure-quickstart-templates%2Fmaster%2F101-storage-account-create%2Fazuredeploy.json) |
| [`token-classification`](./token-classification) | CoNLL NER | ✅ | ✅ | ✅ | - | - |
| [`multiple-choice`](./multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb) | - |
| [`question-answering`](./question-answering) | SQuAD | - | ✅ | - | - | - |
......
@@ -157,3 +157,23 @@ Larger batch size may improve the performance while costing more memory.
}
```
## SQuAD with the TensorFlow Trainer
```bash
python run_tf_squad.py \
--model_name_or_path bert-base-uncased \
--output_dir model \
--max_seq_length 384 \
--num_train_epochs 2 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 16 \
--do_train \
--logging_dir logs \
--mode question-answering \
--logging_steps 10 \
--learning_rate 3e-5 \
--doc_stride 128 \
--optimizer_name adamw
```
For the moment, only training is available in the TensorFlow Trainer; evaluation is not yet supported.
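Until evaluation lands in the TF trainer, a saved checkpoint can still be spot-checked by loading it back and running predictions with the `question-answering` pipeline. A minimal sketch follows; the `model` output directory comes from the command above, and the question/context strings are made-up illustrations:

```python
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, pipeline

# Load the checkpoint written by run_tf_squad.py (path assumed from --output_dir above).
tokenizer = AutoTokenizer.from_pretrained("model")
model = TFAutoModelForQuestionAnswering.from_pretrained("model")

# The pipeline handles tokenization and start/end span decoding.
nlp = pipeline("question-answering", model=model, tokenizer=tokenizer, framework="tf")

result = nlp(
    question="Which dataset is the model fine-tuned on?",
    context="This example fine-tunes a BERT model on SQuAD with the TensorFlow trainer.",
)
print(result)  # {'score': ..., 'start': ..., 'end': ..., 'answer': ...}
```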
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for question-answering."""
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
from transformers import (
AutoConfig,
AutoTokenizer,
HfArgumentParser,
TFAutoModelForQuestionAnswering,
TFTrainer,
TFTrainingArguments,
squad_convert_examples_to_features,
)
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
# If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
# or just modify its tokenizer_config.json.
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)


@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
data_dir: Optional[str] = field(
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
)
max_seq_length: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
doc_stride: int = field(
default=128,
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
)
max_query_length: int = field(
default=64,
metadata={
"help": "The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length."
},
)
max_answer_length: int = field(
default=30,
metadata={
"help": "The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
version_2_with_negative: bool = field(
default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
)
null_score_diff_threshold: float = field(
default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
)
    n_best_size: int = field(
        default=20, metadata={"help": "The total number of n-best predictions to generate in the nbest_predictions.json output file."}
    )
lang_id: int = field(
default=0,
metadata={
"help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
},
)


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
# Prepare Question-Answering task
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast,
)
with training_args.strategy.scope():
model = TFAutoModelForQuestionAnswering.from_pretrained(
model_args.model_name_or_path,
from_pt=bool(".bin" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# Get datasets
if not data_args.data_dir:
if data_args.version_2_with_negative:
            logger.warning("tensorflow_datasets does not handle version 2 of SQuAD. Switching to version 1 automatically.")
try:
import tensorflow_datasets as tfds
except ImportError:
            raise ImportError("If no data_dir is specified, tensorflow_datasets needs to be installed.")
tfds_examples = tfds.load("squad")
train_examples = (
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
if training_args.do_train
else None
)
eval_examples = (
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=True)
if training_args.do_eval
else None
)
else:
processor = SquadV2Processor() if data_args.version_2_with_negative else SquadV1Processor()
train_examples = processor.get_train_examples(data_args.data_dir) if training_args.do_train else None
eval_examples = processor.get_dev_examples(data_args.data_dir) if training_args.do_eval else None
train_dataset = (
squad_convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
doc_stride=data_args.doc_stride,
max_query_length=data_args.max_query_length,
is_training=True,
return_dataset="tf",
)
if training_args.do_train
else None
)
eval_dataset = (
squad_convert_examples_to_features(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
doc_stride=data_args.doc_stride,
max_query_length=data_args.max_query_length,
is_training=False,
return_dataset="tf",
)
if training_args.do_eval
else None
)
# Initialize our Trainer
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
# Training
if training_args.do_train:
trainer.train()
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
main()
@@ -226,7 +226,11 @@ class TFTrainer:
         else:
             metrics = {}
-        metrics["loss"] = loss.numpy()
+        metrics["eval_loss"] = loss.numpy()
+        for key in list(metrics.keys()):
+            if not key.startswith("eval_"):
+                metrics[f"eval_{key}"] = metrics.pop(key)
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
@@ -333,7 +337,7 @@ class TFTrainer:
         gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
         vars = self.model.trainable_variables
-        if self.args.mode == "token-classification":
+        if self.args.mode in ["token-classification", "question-answering"]:
             vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
         self.optimizer.apply_gradients(list(zip(gradients, vars)))
@@ -373,7 +377,7 @@ class TFTrainer:
         per_example_loss, _ = self._run_model(features, labels, True)
         vars = self.model.trainable_variables
-        if self.args.mode == "token-classification":
+        if self.args.mode in ["token-classification", "question-answering"]:
             vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
         gradients = self.optimizer.get_gradients(per_example_loss, vars)
@@ -390,7 +394,7 @@ class TFTrainer:
             labels: the batched labels.
             training: run the model in training mode or not
         """
-        if self.args.mode == "sequence-classification" or self.args.mode == "token-classification":
+        if self.args.mode == "text-classification" or self.args.mode == "token-classification":
             logits = self.model(features, training=training)[0]
         else:
             logits = self.model(features, training=training)
@@ -400,6 +404,10 @@ class TFTrainer:
             reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
             labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
             loss = self.loss(labels, reduced_logits)
+        elif self.args.mode == "question-answering":
+            start_loss = self.loss(labels["start_position"], logits[0])
+            end_loss = self.loss(labels["end_position"], logits[1])
+            loss = (start_loss + end_loss) / 2.0
         else:
             loss = self.loss(labels, logits)
......
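For reference, the question-answering branch added above averages the start- and end-position losses. A small standalone sketch of that computation; the logits, labels, and the `SparseCategoricalCrossentropy(from_logits=True)` loss are illustrative assumptions, not the trainer's exact wiring:

```python
import tensorflow as tf

# Hypothetical batch: 2 examples, sequence length 5.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

start_logits = tf.random.normal((2, 5))  # per-token scores for the answer start
end_logits = tf.random.normal((2, 5))    # per-token scores for the answer end
labels = {
    "start_position": tf.constant([1, 3]),
    "end_position": tf.constant([2, 4]),
}

start_loss = loss_fn(labels["start_position"], start_logits)
end_loss = loss_fn(labels["end_position"], end_logits)
loss = (start_loss + end_loss) / 2.0  # same averaging as the diff above
print(float(loss))
```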
@@ -21,8 +21,8 @@ class TFTrainingArguments(TrainingArguments):
         },
     )
     mode: str = field(
-        default="sequence-classification",
-        metadata={"help": 'Type of task, one of "sequence-classification", "token-classification" '},
+        default="text-classification",
+        metadata={"help": 'Type of task, one of "text-classification", "token-classification", "question-answering"'},
     )
     loss_name: str = field(
         default="SparseCategoricalCrossentropy",
......