Commit d0036640 authored by Myle Ott, committed by Facebook Github Bot

Minor fixes for RACE finetuning (#818)

Summary:
- remove unnecessary extra spaces from RACE data during preprocessing
- fix finetuning instructions (add `--truncate-sequence` and `--dropout` params)
- close file handle in SentenceRankingTask
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/818

Differential Revision: D16770055

Pulled By: myleott

fbshipit-source-id: 2c80084e92cdf8692f2ea7e43f7c344c402b9e61
parent 2b68e91f
@@ -4,38 +4,42 @@
 ### 2) Preprocess RACE data:
 ```bash
-python ./examples/roberta/preprocess_RACE.py <input-dir> <extracted-data-dir>
+python ./examples/roberta/preprocess_RACE.py --input-dir <input-dir> --output-dir <extracted-data-dir>
 ./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
 ```

 ### 3) Fine-tuning on RACE:
 ```bash
-MAX_EPOCHS=5    # epoch number
+MAX_EPOCH=5     # Number of training epochs.
 LR=1e-05        # Peak LR for fixed LR scheduler.
 NUM_CLASSES=4
-MAX_SENTENCES=2 # batch size
+MAX_SENTENCES=1 # Batch size per GPU.
+UPDATE_FREQ=8   # Accumulate gradients to simulate training on 8 GPUs.
+DATA_DIR=/path/to/race-output-dir
 ROBERTA_PATH=/path/to/roberta/model.pt

-CUDA_VISIBLE_DEVICES=0 python train.py <race-preprocessed-dir>/ \
+CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR \
     --restore-file $ROBERTA_PATH \
-    --max-positions 512 \
-    --max-sentences $MAX_SENTENCES \
-    --task sentence_ranking \
     --reset-optimizer --reset-dataloader --reset-meters \
-    --required-batch-size-multiple 1 \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
+    --task sentence_ranking \
+    --num-classes $NUM_CLASSES \
     --init-token 0 --separator-token 2 \
+    --max-option-length 128 \
+    --max-positions 512 \
+    --truncate-sequence \
     --arch roberta_large \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
     --criterion sentence_ranking \
-    --num-classes $NUM_CLASSES \
-    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
     --clip-norm 0.0 \
     --lr-scheduler fixed --lr $LR \
     --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
-    --max-epoch 10 \
-    --update-freq 8 \
-    --find-unused-parameters \
-    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
+    --max-sentences $MAX_SENTENCES \
+    --required-batch-size-multiple 1 \
+    --update-freq $UPDATE_FREQ \
+    --max-epoch $MAX_EPOCH
 ```

 **Note:**
...
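A quick sanity check on the batch-size changes in this hunk (an illustrative note, not part of the commit): the effective batch size is the per-GPU batch times the number of GPUs times the gradient-accumulation steps, so the old and new configurations both train on 16 examples per update.

```python
# Illustrative helper (not in fairseq): effective batch size under the
# README's settings = per-GPU batch * num GPUs * gradient accumulation.
def effective_batch_size(max_sentences, num_gpus, update_freq):
    return max_sentences * num_gpus * update_freq

print(effective_batch_size(2, 1, 8))  # old: MAX_SENTENCES=2, 1 GPU, --update-freq 8 -> 16
print(effective_batch_size(1, 2, 8))  # new: MAX_SENTENCES=1, 2 GPUs, UPDATE_FREQ=8 -> 16
```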
@@ -4,9 +4,11 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
 import json
 import os
+import re


 class InputExample:
@@ -37,6 +39,7 @@ def get_examples(data_dir, set_type):
                 options = cur_data["options"]
                 questions = cur_data["questions"]
                 context = cur_data["article"].replace("\n", " ")
+                context = re.sub(r'\s+', ' ', context)
                 for i in range(len(answers)):
                     label = ord(answers[i]) - ord("A")
                     qa_list = []
@@ -47,6 +50,7 @@ def get_examples(data_dir, set_type):
                             qa_cat = question.replace("_", option)
                         else:
                             qa_cat = " ".join([question, option])
+                        qa_cat = re.sub(r'\s+', ' ', qa_cat)
                         qa_list.append(qa_cat)
                     examples.append(InputExample(context, qa_list, label))
@@ -68,12 +72,15 @@ def main():
     )
     args = parser.parse_args()

+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir, exist_ok=True)
+
     for set_type in ["train", "dev", "test-middle", "test-high"]:
         examples = get_examples(args.input_dir, set_type)
-        qa_file_paths = [args.output_dir + set_type + ".input" + str(i + 1) for i in range(4)]
+        qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
         qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
-        outf_context_path = args.output_dir + set_type + ".input0"
-        outf_label_path = args.output_dir + set_type + ".label"
+        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
+        outf_label_path = os.path.join(args.output_dir, set_type + ".label")
         outf_context = open(outf_context_path, 'w')
         outf_label = open(outf_label_path, 'w')
         for example in examples:
...
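The `re.sub(r'\s+', ' ', ...)` calls added to preprocess_RACE.py implement the "remove unnecessary extra spaces" item from the summary. A minimal standalone sketch of the effect:

```python
import re

def normalize_whitespace(text):
    # Collapse any run of whitespace (spaces, tabs, and the newlines left
    # over after replacing "\n" in the article) into a single space,
    # as the new re.sub calls do.
    return re.sub(r'\s+', ' ', text)

print(normalize_whitespace("What  does\nthe author\t mean?"))
# -> 'What does the author mean?'
```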
@@ -42,7 +42,6 @@ for INPUT_TYPE in $INPUT_TYPES
 do
   LANG="input$INPUT_TYPE"
   fairseq-preprocess \
-    --dataset-impl cached \
     --only-source \
     --trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
     --validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
...
@@ -12,6 +12,7 @@ class TruncateDataset(BaseWrapperDataset):
     def __init__(self, dataset, truncation_length):
         super().__init__(dataset)
+        assert truncation_length is not None
         self.truncation_length = truncation_length
         self.dataset = dataset
...
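The new assert makes TruncateDataset fail fast at construction time when no length is supplied. For intuition (a standalone stand-in, not the actual fairseq class): plain slicing with a None bound silently returns the full sequence, which is exactly the misconfiguration the assert now catches.

```python
import torch

def truncate_item(item, truncation_length):
    # Stand-in for TruncateDataset's per-item truncation. Without the
    # assert, item[:None] would quietly return the untruncated sequence.
    assert truncation_length is not None
    return item[:truncation_length]

print(truncate_item(torch.arange(10), 4))  # tensor([0, 1, 2, 3])
```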
@@ -39,16 +39,16 @@ class SentenceRankingTask(FairseqTask):
         """Add task-specific arguments to the parser."""
         parser.add_argument('data', metavar='FILE',
                             help='file prefix for data')
-        parser.add_argument('--num-classes', type=int, default=2,
+        parser.add_argument('--num-classes', type=int,
                             help='number of sentences to be ranked')
-        parser.add_argument('--init-token', type=int, default=None,
+        parser.add_argument('--init-token', type=int,
                             help='add token at the beginning of each batch item')
-        parser.add_argument('--separator-token', type=int, default=None,
+        parser.add_argument('--separator-token', type=int,
                             help='add separator token between inputs')
-        parser.add_argument('--no-shuffle', action='store_true', default=False)
-        parser.add_argument('--truncate-sequence', action='store_true', default=False,
-                            help='Truncate sequence to max_sequence_length')
-        parser.add_argument('--max-option-length', type=int, default=None,
+        parser.add_argument('--no-shuffle', action='store_true')
+        parser.add_argument('--truncate-sequence', action='store_true',
+                            help='Truncate sequence to max_positions')
+        parser.add_argument('--max-option-length', type=int,
                             help='max length for each option')

     def __init__(self, args, dictionary):
@@ -71,8 +71,6 @@ class SentenceRankingTask(FairseqTask):
         assert args.criterion == 'sentence_ranking', \
             'Must set --criterion=sentence_ranking'

-        args.tokens_per_sample = args.max_positions
-
         # load data dictionary
         data_dict = cls.load_dictionary(
             args,
@@ -115,6 +113,7 @@ class SentenceRankingTask(FairseqTask):
         for input_option in input_options:
             if self.args.init_token is not None:
                 input_option = PrependTokenDataset(input_option, self.args.init_token)
-            input_option = TruncateDataset(input_option, self.args.max_option_length)
+            if self.args.max_option_length is not None:
+                input_option = TruncateDataset(input_option, self.args.max_option_length)
             src_token = ConcatSentencesDataset(input_option, input0)
             if self.args.truncate_sequence:
@@ -145,9 +144,10 @@ class SentenceRankingTask(FairseqTask):
         label_path = '{}.label'.format(get_path('label', split))
         if os.path.exists(label_path):
-            dataset.update(
-                target=RawLabelDataset([
-                    int(x.strip()) for x in open(label_path).readlines()
-                ])
-            )
+            with open(label_path) as h:
+                dataset.update(
+                    target=RawLabelDataset([
+                        int(x.strip()) for x in h.readlines()
+                    ])
+                )
...
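The last hunk is the "close file handle in SentenceRankingTask" fix from the summary: the label file used to be opened inline and never closed. A minimal sketch of the pattern (the helper name is hypothetical):

```python
def load_labels(label_path):
    # The with-block closes the handle even if int() raises on a
    # malformed line, unlike the old bare open(label_path).readlines().
    with open(label_path) as h:
        return [int(x.strip()) for x in h]
```

Iterating over the handle directly, rather than calling readlines() as the diff does, also avoids materializing an intermediate list; either way the handle is reliably closed.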