Commit b6c55b62 authored by Jingfei Du's avatar Jingfei Du Committed by Facebook Github Bot
Browse files

added sentence ranking task and loss (#809)

Summary:
This task and loss are used for sentence ranking and multiple choice tasks such as RACE
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/809

Reviewed By: myleott

Differential Revision: D16715745

Pulled By: jingfeidu

fbshipit-source-id: cb4d1c7b26ebb3e2382449ba51af5745ef56f30f
parent 838e108a
# Finetuning RoBERTa on RACE tasks
### 1) Download the data from RACE website (http://www.cs.cmu.edu/~glai1/data/race/)
### 2) Preprocess RACE data:
```bash
python ./examples/roberta/preprocess_RACE.py <input-dir> <extracted-data-dir>
./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
```
### 3) Fine-tuning on RACE:
```bash
MAX_EPOCHS=5 # epoch number
LR=1e-05 # Peak LR for fixed LR scheduler.
NUM_CLASSES=4
MAX_SENTENCES=2 # batch size
ROBERTA_PATH=/path/to/roberta/model.pt
CUDA_VISIBLE_DEVICES=0 python train.py <race-preprocessed-dir>/ \
--restore-file $ROBERTA_PATH \
--max-positions 512 \
--max-sentences $MAX_SENTENCES \
--task sentence_ranking \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 --separator-token 2 \
--arch roberta_large \
--criterion sentence_ranking \
--num-classes $NUM_CLASSES \
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
--clip-norm 0.0 \
--lr-scheduler fixed --lr $LR \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--update-freq 8 \
--find-unused-parameters \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```
**Note:**
a) As contexts in RACE are relatively long, we are using smaller batch size per GPU while increasing update-freq to achieve larger effective batch size.
b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`.
c) The setting in above command is based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search.
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os
class InputExample:
def __init__(self, paragraph, qa_list, label):
self.paragraph = paragraph
self.qa_list = qa_list
self.label = label
def get_examples(data_dir, set_type):
"""
Extract paragraph and question-answer list from each json file
"""
examples = []
levels = ["middle", "high"]
set_type_c = set_type.split('-')
if len(set_type_c) == 2:
levels = [set_type_c[1]]
set_type = set_type_c[0]
for level in levels:
cur_dir = os.path.join(data_dir, set_type, level)
for filename in os.listdir(cur_dir):
cur_path = os.path.join(cur_dir, filename)
with open(cur_path, 'r') as f:
cur_data = json.load(f)
answers = cur_data["answers"]
options = cur_data["options"]
questions = cur_data["questions"]
context = cur_data["article"].replace("\n", " ")
for i in range(len(answers)):
label = ord(answers[i]) - ord("A")
qa_list = []
question = questions[i]
for j in range(4):
option = options[i][j]
if "_" in question:
qa_cat = question.replace("_", option)
else:
qa_cat = " ".join([question, option])
qa_list.append(qa_cat)
examples.append(InputExample(context, qa_list, label))
return examples
def main():
"""
Helper script to extract paragraphs questions and answers from RACE datasets.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-dir",
help='input directory for downloaded RACE dataset',
)
parser.add_argument(
"--output-dir",
help='output directory for extracted data',
)
args = parser.parse_args()
for set_type in ["train", "dev", "test-middle", "test-high"]:
examples = get_examples(args.input_dir, set_type)
qa_file_paths = [args.output_dir + set_type + ".input" + str(i + 1) for i in range(4)]
qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
outf_context_path = args.output_dir + set_type + ".input0"
outf_label_path = args.output_dir + set_type + ".label"
outf_context = open(outf_context_path, 'w')
outf_label = open(outf_label_path, 'w')
for example in examples:
outf_context.write(example.paragraph + '\n')
for i in range(4):
qa_files[i].write(example.qa_list[i] + '\n')
outf_label.write(str(example.label) + '\n')
for f in qa_files:
f.close()
outf_label.close()
outf_context.close()
if __name__ == '__main__':
main()
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# data should be downloaded and processed with reprocess_RACE.py
if [[ $# -ne 2 ]]; then
echo "Run as following:"
echo "./examples/roberta/preprocess_RACE.sh <race_data_folder> <output_folder>"
exit 1
fi
RACE_DATA_FOLDER=$1
OUT_DATA_FOLDER=$2
# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
SPLITS="train dev test-middle test-high"
INPUT_TYPES="input0 input1 input2 input3 input4"
for INPUT_TYPE in $INPUT_TYPES
do
for SPLIT in $SPLITS
do
echo "BPE encoding $SPLIT/$INPUT_TYPE"
python -m examples.roberta.multiprocessing_bpe_encoder \
--encoder-json encoder.json \
--vocab-bpe vocab.bpe \
--inputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE" \
--outputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE.bpe" \
--workers 10 \
--keep-empty;
done
done
for INPUT_TYPE in $INPUT_TYPES
do
LANG="input$INPUT_TYPE"
fairseq-preprocess \
--dataset-impl cached \
--only-source \
--trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
--validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
--testpref "$RACE_DATA_FOLDER/test-middle.$INPUT_TYPE.bpe,$RACE_DATA_FOLDER/test-high.$INPUT_TYPE.bpe" \
--destdir "$OUT_DATA_FOLDER/$INPUT_TYPE" \
--workers 10 \
--srcdict dict.txt;
done
rm -rf "$OUT_DATA_FOLDER/label"
mkdir -p "$OUT_DATA_FOLDER/label"
cp "$RACE_DATA_FOLDER/train.label" "$OUT_DATA_FOLDER/label/"
cp "$RACE_DATA_FOLDER/dev.label" "$OUT_DATA_FOLDER/label/valid.label"
cp "$RACE_DATA_FOLDER/test-middle.label" "$OUT_DATA_FOLDER/label/test.label"
cp "$RACE_DATA_FOLDER/test-high.label" "$OUT_DATA_FOLDER/label/test1.label"
\ No newline at end of file
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('sentence_ranking')
class SentenceRankingCriterion(FairseqCriterion):
def forward(self, model, sample, reduce=True):
"""Compute ranking loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
scores = []
for idx in range(self.args.num_classes):
score, _ = model(
**sample['net_input{idx}'.format(idx=idx+1)],
features_only=True,
classification_head_name='sentence_classification_head',
)
scores.append(score)
logits = torch.cat(scores, dim=1)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
loss = F.nll_loss(
F.log_softmax(logits, dim=-1, dtype=torch.float32),
targets,
reduction='sum',
)
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample_size,
'sample_size': sample_size,
}
logging_output.update(
ncorrect=(logits.max(dim=1)[1] == targets).sum().item()
)
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]:
ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
agg_output.update(accuracy=ncorrect/nsentences)
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from fairseq.data import (
ConcatSentencesDataset,
data_utils,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
OffsetTokensDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
TruncateDataset,
)
from . import FairseqTask, register_task
@register_task('sentence_ranking')
class SentenceRankingTask(FairseqTask):
"""
Ranking task on multiple sentences.
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='FILE',
help='file prefix for data')
parser.add_argument('--num-classes', type=int, default=2,
help='number of sentences to be ranked')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
parser.add_argument('--separator-token', type=int, default=None,
help='add separator token between inputs')
parser.add_argument('--no-shuffle', action='store_true', default=False)
parser.add_argument('--truncate-sequence', action='store_true', default=False,
help='Truncate sequence to max_sequence_length')
parser.add_argument('--max-option-length', type=int, default=None,
help='max length for each option')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
@classmethod
def load_dictionary(cls, args, filename, source=True):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'sentence_ranking', \
'Must set --criterion=sentence_ranking'
args.tokens_per_sample = args.max_positions
# load data dictionary
data_dict = cls.load_dictionary(
args,
os.path.join(args.data, 'input0', 'dict.txt'),
source=True,
)
print('| [input] dictionary: {} types'.format(len(data_dict)))
return SentenceRankingTask(args, data_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
def get_path(type, split):
return os.path.join(self.args.data, type, split)
def make_dataset(type, dictionary):
split_path = get_path(type, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.args.dataset_impl,
combine=combine,
)
return dataset
input0 = make_dataset('input0', self.source_dictionary)
input_options = [
make_dataset(
'input{idx}'.format(idx=idx + 1),
self.source_dictionary
)
for idx in range(self.args.num_classes)
]
if self.args.separator_token is not None:
input0 = PrependTokenDataset(input0, self.args.separator_token)
src_tokens = []
for input_option in input_options:
if self.args.init_token is not None:
input_option = PrependTokenDataset(input_option, self.args.init_token)
input_option = TruncateDataset(input_option, self.args.max_option_length)
src_token = ConcatSentencesDataset(input_option, input0)
if self.args.truncate_sequence:
src_token = TruncateDataset(src_token, self.args.max_positions)
src_tokens.append(src_token)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(src_tokens[0]))
dataset = {
'id': IdDataset(),
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(src_tokens[0], reduce=True),
}
for src_token_idx in range(len(src_tokens)):
dataset.update(
{
'net_input{idx}'.format(idx=src_token_idx+1): {
'src_tokens': RightPadDataset(
src_tokens[src_token_idx],
pad_idx=self.source_dictionary.pad(),
),
'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
}
}
)
label_path = f"{get_path('label', split)}.label"
if os.path.exists(label_path):
dataset.update(
target=RawLabelDataset([
int(x.strip()) for x in open(label_path).readlines()
])
)
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
)
if self.args.no_shuffle:
dataset = nested_dataset
else:
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
model.register_classification_head(
'sentence_classification_head',
num_classes=1,
)
return model
def max_positions(self):
return self.args.max_positions
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment