Commit d0036640 authored by Myle Ott, committed by Facebook Github Bot

Minor fixes for RACE finetuning (#818)

Summary:
- remove unnecessary extra spaces from RACE data during preprocessing (see the sketch below)
- fix finetuning instructions (add the `--truncate-sequence` and `--dropout` params)
- close the label file handle in SentenceRankingTask
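
The whitespace fix boils down to collapsing any run of whitespace into a single space with `re.sub`, as done at both call sites in `preprocess_RACE.py` below. A minimal sketch (the `normalize_ws` helper name is illustrative, not from the commit):

```python
import re

def normalize_ws(text):
    # Collapse newlines, tabs, and repeated spaces into single spaces.
    return re.sub(r'\s+', ' ', text)

print(normalize_ws("An  article\nwith   extra  spaces"))
# -> 'An article with extra spaces'
```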
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/818

Differential Revision: D16770055

Pulled By: myleott

fbshipit-source-id: 2c80084e92cdf8692f2ea7e43f7c344c402b9e61
parent 2b68e91f
examples/roberta/README.race.md
@@ -4,38 +4,42 @@
 ### 2) Preprocess RACE data:
 ```bash
-python ./examples/roberta/preprocess_RACE.py <input-dir> <extracted-data-dir>
+python ./examples/roberta/preprocess_RACE.py --input-dir <input-dir> --output-dir <extracted-data-dir>
 ./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
 ```

 ### 3) Fine-tuning on RACE:
 ```bash
-MAX_EPOCHS=5       # epoch number
+MAX_EPOCH=5        # Number of training epochs.
 LR=1e-05           # Peak LR for fixed LR scheduler.
 NUM_CLASSES=4
-MAX_SENTENCES=2    # batch size
+MAX_SENTENCES=1    # Batch size per GPU.
+UPDATE_FREQ=8      # Accumulate gradients to simulate training on 8 GPUs.
+DATA_DIR=/path/to/race-output-dir
 ROBERTA_PATH=/path/to/roberta/model.pt

-CUDA_VISIBLE_DEVICES=0 python train.py <race-preprocessed-dir>/ \
+CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR \
     --restore-file $ROBERTA_PATH \
-    --max-positions 512 \
-    --max-sentences $MAX_SENTENCES \
-    --task sentence_ranking \
     --reset-optimizer --reset-dataloader --reset-meters \
-    --required-batch-size-multiple 1 \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
+    --task sentence_ranking \
+    --num-classes $NUM_CLASSES \
     --init-token 0 --separator-token 2 \
+    --max-option-length 128 \
+    --max-positions 512 \
+    --truncate-sequence \
     --arch roberta_large \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
     --criterion sentence_ranking \
-    --num-classes $NUM_CLASSES \
-    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
     --clip-norm 0.0 \
     --lr-scheduler fixed --lr $LR \
     --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
-    --max-epoch 10 \
-    --update-freq 8 \
-    --find-unused-parameters \
-    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
+    --max-sentences $MAX_SENTENCES \
+    --required-batch-size-multiple 1 \
+    --update-freq $UPDATE_FREQ \
+    --max-epoch $MAX_EPOCH
 ```
 **Note:**
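
For reference, the updated settings keep the effective batch size unchanged: MAX_SENTENCES × 2 GPUs × UPDATE_FREQ = 1 × 2 × 8 = 16 sequences per update, the same as the old single-GPU configuration (2 × 8 = 16).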
examples/roberta/preprocess_RACE.py
@@ -4,9 +4,11 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

 import argparse
 import json
 import os
+import re


 class InputExample:
@@ -37,6 +39,7 @@ def get_examples(data_dir, set_type):
             options = cur_data["options"]
             questions = cur_data["questions"]
             context = cur_data["article"].replace("\n", " ")
+            context = re.sub(r'\s+', ' ', context)
             for i in range(len(answers)):
                 label = ord(answers[i]) - ord("A")
                 qa_list = []
@@ -47,6 +50,7 @@ def get_examples(data_dir, set_type):
                         qa_cat = question.replace("_", option)
                     else:
                         qa_cat = " ".join([question, option])
+                    qa_cat = re.sub(r'\s+', ' ', qa_cat)
                     qa_list.append(qa_cat)
                 examples.append(InputExample(context, qa_list, label))
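
For context, RACE mixes cloze-style questions (containing a `_` placeholder) with ordinary questions, which is why `get_examples` branches above. A hypothetical example of each case:

```python
# Hypothetical RACE-style items illustrating the two branches.
option = "market"

cloze = "He went to the _ yesterday."
print(cloze.replace("_", option))    # He went to the market yesterday.

plain = "Why did he go?"
print(" ".join([plain, option]))     # Why did he go? market
```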
@@ -68,12 +72,15 @@ def main():
     )
     args = parser.parse_args()

     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir, exist_ok=True)

     for set_type in ["train", "dev", "test-middle", "test-high"]:
         examples = get_examples(args.input_dir, set_type)
-        qa_file_paths = [args.output_dir + set_type + ".input" + str(i + 1) for i in range(4)]
+        qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
         qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
-        outf_context_path = args.output_dir + set_type + ".input0"
-        outf_label_path = args.output_dir + set_type + ".label"
+        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
+        outf_label_path = os.path.join(args.output_dir, set_type + ".label")
         outf_context = open(outf_context_path, 'w')
         outf_label = open(outf_label_path, 'w')
         for example in examples:
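
The switch to `os.path.join` matters because plain concatenation silently builds a wrong path whenever `--output-dir` is passed without a trailing slash:

```python
import os

output_dir, set_type = "/race-out", "train"            # hypothetical values
print(output_dir + set_type + ".input0")               # /race-outtrain.input0 (wrong file)
print(os.path.join(output_dir, set_type + ".input0"))  # /race-out/train.input0
```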
examples/roberta/preprocess_RACE.sh
@@ -42,7 +42,6 @@ for INPUT_TYPE in $INPUT_TYPES
 do
   LANG="input$INPUT_TYPE"
   fairseq-preprocess \
-    --dataset-impl cached \
     --only-source \
     --trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
     --validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
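
For reference, `preprocess_RACE.sh` loops over each input type, takes the BPE-encoded splits (the `*.bpe` files), and binarizes them with `fairseq-preprocess --only-source`; removing `--dataset-impl cached` simply falls back to fairseq's default dataset implementation.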
fairseq/data/truncate_dataset.py
@@ -12,6 +12,7 @@ class TruncateDataset(BaseWrapperDataset):

     def __init__(self, dataset, truncation_length):
         super().__init__(dataset)
+        assert truncation_length is not None
         self.truncation_length = truncation_length
         self.dataset = dataset
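
The new assert surfaces a `None` truncation length at construction time instead of failing later during slicing. A minimal sketch of the wrapper's behavior (illustrative, not the exact fairseq implementation):

```python
class TruncateDatasetSketch:
    """Wraps a dataset and clips each item to at most truncation_length tokens."""

    def __init__(self, dataset, truncation_length):
        assert truncation_length is not None  # fail fast, as in the diff above
        self.dataset = dataset
        self.truncation_length = truncation_length

    def __getitem__(self, index):
        item = self.dataset[index]
        return item[: self.truncation_length]

    def __len__(self):
        return len(self.dataset)
```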
fairseq/tasks/sentence_ranking.py
@@ -39,16 +39,16 @@ class SentenceRankingTask(FairseqTask):
         """Add task-specific arguments to the parser."""
         parser.add_argument('data', metavar='FILE',
                             help='file prefix for data')
-        parser.add_argument('--num-classes', type=int, default=2,
+        parser.add_argument('--num-classes', type=int,
                             help='number of sentences to be ranked')
-        parser.add_argument('--init-token', type=int, default=None,
+        parser.add_argument('--init-token', type=int,
                             help='add token at the beginning of each batch item')
-        parser.add_argument('--separator-token', type=int, default=None,
+        parser.add_argument('--separator-token', type=int,
                             help='add separator token between inputs')
-        parser.add_argument('--no-shuffle', action='store_true', default=False)
-        parser.add_argument('--truncate-sequence', action='store_true', default=False,
-                            help='Truncate sequence to max_sequence_length')
-        parser.add_argument('--max-option-length', type=int, default=None,
+        parser.add_argument('--no-shuffle', action='store_true')
+        parser.add_argument('--truncate-sequence', action='store_true',
+                            help='Truncate sequence to max_positions')
+        parser.add_argument('--max-option-length', type=int,
                             help='max length for each option')

     def __init__(self, args, dictionary):
@@ -71,8 +71,6 @@ class SentenceRankingTask(FairseqTask):
         assert args.criterion == 'sentence_ranking', \
             'Must set --criterion=sentence_ranking'

-        args.tokens_per_sample = args.max_positions
-
         # load data dictionary
         data_dict = cls.load_dictionary(
             args,
@@ -115,6 +113,7 @@ class SentenceRankingTask(FairseqTask):
         for input_option in input_options:
             if self.args.init_token is not None:
                 input_option = PrependTokenDataset(input_option, self.args.init_token)
-            input_option = TruncateDataset(input_option, self.args.max_option_length)
+            if self.args.max_option_length is not None:
+                input_option = TruncateDataset(input_option, self.args.max_option_length)
             src_token = ConcatSentencesDataset(input_option, input0)
             if self.args.truncate_sequence:
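
Putting the pieces together: each candidate option is prefixed with the init token, clipped to `--max-option-length` when set, concatenated with the passage, and, when `--truncate-sequence` is given, clipped again to the model's maximum positions. A rough sketch with plain token lists (hypothetical helper, not the dataset-wrapper code fairseq actually uses):

```python
def build_candidate(option_tokens, context_tokens, args):
    # Mirrors PrependTokenDataset / TruncateDataset / ConcatSentencesDataset above.
    if args.init_token is not None:
        option_tokens = [args.init_token] + option_tokens
    if args.max_option_length is not None:
        option_tokens = option_tokens[: args.max_option_length]
    tokens = option_tokens + context_tokens
    if args.truncate_sequence:
        tokens = tokens[: args.max_positions]
    return tokens
```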
@@ -145,9 +144,10 @@ class SentenceRankingTask(FairseqTask):
         label_path = '{}.label'.format(get_path('label', split))
         if os.path.exists(label_path):
-            dataset.update(
-                target=RawLabelDataset([
-                    int(x.strip()) for x in open(label_path).readlines()
-                ])
-            )
+            with open(label_path) as h:
+                dataset.update(
+                    target=RawLabelDataset([
+                        int(x.strip()) for x in h.readlines()
+                    ])
+                )
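
Reading the labels inside a `with` block guarantees the file handle is closed even if an exception is raised mid-read, which is the SentenceRankingTask leak called out in the summary.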