Unverified commit 1b871e09, authored by karthikrangasai, committed by GitHub

Support Seq2Seq models for the question answering task (#13432)

* Add a seq2seq example for QA on the SQuAD dataset.

* Address review comments: fix styling mistakes.

* Add a how-to example to the README; simplify access to the dataset's preprocess function.

* Add tests for the seq2seq QA example.

* Change dataset column name to fix tests.

* Fix test command mistake.

* Add the missing argument 'ignore_pad_token_for_loss' to DataTrainingArguments.

* Add the missing argument 'num_beams' to DataTrainingArguments.

* Fix processing of the predicted output token ids so that tokenizer decoding receives valid input; update the assertion conditions in the tests (see the sketch below).
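For context, the last fix concerns how generated ids are prepared for decoding: positions masked out of the loss are stored as -100, which is not a valid token id, so the tokenizer cannot decode them directly. A minimal sketch of the pattern, illustrative only (the exact code lives in `run_seq2seq_qa.py`):

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

def decode_predictions(pred_ids: np.ndarray) -> list:
    # Swap the -100 loss-masking sentinel for the pad token id so that
    # tokenizer.batch_decode only sees valid token ids.
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    return tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
```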
parent 6b83090e
@@ -57,6 +57,28 @@
f1 = 88.52
exact_match = 81.22
```
### Fine-tuning T5 on SQuAD2.0
This example code fine-tunes T5 on the SQuAD2.0 dataset.
```bash
python run_seq2seq_qa.py \
--model_name_or_path t5-small \
--dataset_name squad_v2 \
--context_column context \
--question_column question \
--answer_column answers \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_seq2seq_squad/
```
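Unlike the extractive `run_qa.py` script, this example casts SQuAD into a text-to-text format: the question and context are concatenated into a single source string and the answer text becomes the generation target. A rough sketch of that formatting, assuming SQuAD-style columns (the exact prompt layout is defined in `run_seq2seq_qa.py`, so treat this as an approximation):

```python
def format_example(question: str, context: str, answers: dict):
    # Source: question and context joined into one prompt string.
    source = f"question: {question.strip()} context: {context.strip()}"
    # Target: the first gold answer text; SQuAD2.0 marks unanswerable
    # questions with an empty answers list, so fall back to "".
    target = answers["text"][0] if answers["text"] else ""
    return source, target

# Example usage with a SQuAD-style record:
src, tgt = format_example(
    "What is the capital of France?",
    "Paris is the capital and largest city of France.",
    {"text": ["Paris"], "answer_start": [0]},
)
print(src)  # question: What is the capital of France? context: Paris is ...
print(tgt)  # Paris
```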
#### Distributed training
Here is an example using distributed training on 8 V100 GPUs and the Bert Whole Word Masking uncased model to reach an F1 > 93 on SQuAD1.1:
@@ -57,6 +57,7 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_qa as run_squad
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_summarization
import run_swag
@@ -244,6 +245,40 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
def test_run_squad_seq2seq(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq_qa.py
--model_name_or_path t5-small
--context_column context
--question_column question
--answer_column answers
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=10
--warmup_steps=2
--do_train
--do_eval
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
run_squad_seq2seq.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_rouge1"], 10)
self.assertGreaterEqual(result["eval_rouge2"], 10)
self.assertGreaterEqual(result["eval_rougeL"], 10)
self.assertGreaterEqual(result["eval_rougeLsum"], 10)
def test_run_swag(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)