Commit 138dc8e4 authored by Naman Goyal, committed by Facebook GitHub Bot

adding glue data preprocessing scripts (#771)

Summary:
1) Added GLUE data preprocessing script.
2) Updated README with usage.

TODO:
1) Release the fairseq dictionary and remove the hard-coded path.
2) Remove the hard-coded path for BPE encoding.

myleott, what do you recommend for the above TODOs?
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/771

Reviewed By: myleott

Differential Revision: D16547679

Pulled By: myleott

fbshipit-source-id: 6a6562d9b6215523d048fdf3daee63ffac21e231
parent 33597e5a
@@ -134,9 +134,77 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9060
```
## Finetuning on GLUE tasks
A more detailed tutorial is coming soon.
##### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```
$ wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
$ python download_glue_data.py --data_dir glue_data --tasks all
```
##### 2) Preprocess GLUE task data:
```
$ ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`<glue_task_name>` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all GLUE tasks.
##### 3) Fine-tune on a GLUE task:
Example fine-tuning command for the `RTE` task:
```
TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=122 # 6 percent of the number of updates
LR=2e-05 # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16 # Batch size.
CUDA_VISIBLE_DEVICES=0 python train.py RTE-bin/ \
--restore-file <roberta_large_absolute_path> \
--max-positions 512 \
--max-sentences $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 --separator-token 2 \
--arch roberta_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```
For each GLUE task, you will need the following command-line arguments:
Hyperparameter | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5
`--max-sentences` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16
`--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598
`--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214
For `STS-B`, additionally use the following command-line arguments:
```
--regression-target
--best-checkpoint-metric loss
```
and remove `--maximize-best-checkpoint-metric`.
**Note:**
a) `--total-num-update` is used by the `polynomial_decay` learning-rate scheduler and is calculated for `--max-epoch=10` and `--max-sentences=16/32`, depending on the task.
b) The above command-line arguments and hyperparameters were tested on a single Nvidia `V100` GPU with `32GB` of memory for each task. Depending on the GPU memory available to you, you may need to increase `--update-freq` and reduce `--max-sentences` (see the sketch below).
c) The settings in the above table are suggested values based on our hyperparameter search within a fixed search space (for careful comparison across models). You may be able to find better metrics with a wider hyperparameter search.
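As a rough sanity check for notes a) and b), the short Python sketch below (with hypothetical low-memory values) shows how the warmup budget and the effective batch size relate to the table above:
```
# Illustrative only: warmup is ~6% of the total updates, and the effective batch
# size is what matters when trading --max-sentences for --update-freq.
total_num_update = 2036                         # 10 epochs through RTE at batch size 16
warmup_updates = int(0.06 * total_num_update)
print(warmup_updates)                           # 122, matching --warmup-updates above

max_sentences, update_freq, num_gpus = 8, 2, 1  # hypothetical GPU that cannot fit 16 sentences
effective_batch_size = max_sentences * update_freq * num_gpus
print(effective_batch_size)                     # still 16, so the table values remain valid
```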
## Pretraining using your own data
......
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
from fairseq.data.encoders.gpt2_bpe import get_encoder
def main():
"""
Helper script to encode raw text
with the GPT-2 BPE using multiple processes.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--encoder-json",
help='path to encoder.json',
)
parser.add_argument(
"--vocab-bpe",
type=str,
help='path to vocab.bpe',
)
parser.add_argument(
"--inputs",
nargs="+",
default=['-'],
help="input files to filter/encode",
)
parser.add_argument(
"--outputs",
nargs="+",
default=['-'],
help="path to save encoded outputs",
)
parser.add_argument(
"--keep-empty",
action="store_true",
help="keep empty lines",
)
parser.add_argument("--workers", type=int, default=20)
args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \
"number of input and output paths should match"
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
if input != "-" else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
if output != "-" else sys.stdout
for output in args.outputs
]
encoder = MultiprocessingEncoder(args)
pool = Pool(args.workers, initializer=encoder.initializer)
encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)
stats = Counter()
for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
if filt == "PASS":
for enc_line, output_h in zip(enc_lines, outputs):
print(enc_line, file=output_h)
else:
stats["num_filtered_" + filt] += 1
if i % 10000 == 0:
print("processed {} lines".format(i), file=sys.stderr)
for k, v in stats.most_common():
print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
class MultiprocessingEncoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
global bpe
bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)
def encode(self, line):
global bpe
ids = bpe.encode(line)
return list(map(str, ids))
def decode(self, tokens):
global bpe
return bpe.decode(tokens)
def encode_lines(self, lines):
"""
Encode a set of lines. All lines will be encoded together.
"""
enc_lines = []
for line in lines:
line = line.strip()
if len(line) == 0 and not self.args.keep_empty:
return ["EMPTY", None]
tokens = self.encode(line)
enc_lines.append(" ".join(tokens))
return ["PASS", enc_lines]
def decode_lines(self, lines):
dec_lines = []
for line in lines:
tokens = map(int, line.strip().split())
dec_lines.append(self.decode(tokens))
return ["PASS", dec_lines]
if __name__ == "__main__":
main()
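For illustration, a minimal sketch of the GPT-2 BPE round trip that this script performs per line; it assumes `encoder.json` and `vocab.bpe` have already been downloaded into the working directory (as the preprocessing script below does):

```python
# Minimal sketch: encode one sentence with the GPT-2 BPE and decode it back.
from fairseq.data.encoders.gpt2_bpe import get_encoder

bpe = get_encoder("encoder.json", "vocab.bpe")
ids = bpe.encode("Hello world!")        # list of BPE token ids
print(" ".join(map(str, ids)))          # space-separated ids, as written to *.input0
print(bpe.decode(ids))                  # "Hello world!"
```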
#!/bin/bash
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# Expects raw GLUE data as downloaded by the GLUE download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
if [[ $# -ne 2 ]]; then
echo "Run as following:"
echo "./examples/roberta/preprocess_GLUE_tasks.sh <glud_data_folder> <task_name>"
exit 1
fi
GLUE_DATA_FOLDER=$1
# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
TASKS=$2 # QQP
if [ "$TASKS" = "ALL" ]
then
TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi
for TASK in $TASKS
do
echo "Preprocessing $TASK"
TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER"
SPLITS="train dev test"
INPUT_COUNT=2
if [ "$TASK" = "QQP" ]
then
INPUT_COLUMNS=( 4 5 )
TEST_INPUT_COLUMNS=( 2 3 )
LABEL_COLUMN=6
elif [ "$TASK" = "MNLI" ]
then
SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
INPUT_COLUMNS=( 9 10 )
TEST_INPUT_COLUMNS=( 9 10 )
DEV_LABEL_COLUMN=16
LABEL_COLUMN=12
elif [ "$TASK" = "QNLI" ]
then
INPUT_COLUMNS=( 2 3 )
TEST_INPUT_COLUMNS=( 2 3 )
LABEL_COLUMN=4
elif [ "$TASK" = "MRPC" ]
then
INPUT_COLUMNS=( 4 5 )
TEST_INPUT_COLUMNS=( 4 5 )
LABEL_COLUMN=1
elif [ "$TASK" = "RTE" ]
then
INPUT_COLUMNS=( 2 3 )
TEST_INPUT_COLUMNS=( 2 3 )
LABEL_COLUMN=4
elif [ "$TASK" = "STS-B" ]
then
INPUT_COLUMNS=( 8 9 )
TEST_INPUT_COLUMNS=( 8 9 )
LABEL_COLUMN=10
# The following are single-sentence tasks.
elif [ "$TASK" = "SST-2" ]
then
INPUT_COLUMNS=( 1 )
TEST_INPUT_COLUMNS=( 2 )
LABEL_COLUMN=2
INPUT_COUNT=1
elif [ "$TASK" = "CoLA" ]
then
INPUT_COLUMNS=( 4 )
TEST_INPUT_COLUMNS=( 2 )
LABEL_COLUMN=2
INPUT_COUNT=1
fi
# Strip out header and filter lines that don't have expected number of fields.
rm -rf "$TASK_DATA_FOLDER/processed"
mkdir "$TASK_DATA_FOLDER/processed"
for SPLIT in $SPLITS
do
# CoLA train and dev don't have a header.
if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
then
cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
else
tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
fi
# Remove unformatted lines from train and dev files for QQP dataset.
if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
then
awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
else
cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
fi
rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
done
# Split into input0, input1 and label
for SPLIT in $SPLITS
do
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
do
if [[ "$SPLIT" != test* ]]
then
COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
else
COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
fi
cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
done
if [[ "$SPLIT" != test* ]]
then
if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
then
cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
else
cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
fi
fi
# BPE encode.
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
do
LANG="input$INPUT_TYPE"
echo "BPE encoding $SPLIT/$LANG"
python -m examples.roberta.multiprocessing_bpe_encoder \
--encoder-json encoder.json \
--vocab-bpe vocab.bpe \
--inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
--outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
--workers 60 \
--keep-empty;
done
done
# Remove output directory.
rm -rf "$TASK-bin"
DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
if [ "$TASK" = "MNLI" ]
then
DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
fi
# Run fairseq preprocessing:
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
do
LANG="input$INPUT_TYPE"
python preprocess.py \
--only-source \
--trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
--validpref "${DEVPREF//LANG/$LANG}" \
--testpref "${TESTPREF//LANG/$LANG}" \
--destdir "$TASK-bin/$LANG" \
--workers 60 \
--srcdict dict.txt;
done
if [[ "$TASK" != "STS-B" ]]
then
python preprocess.py \
--only-source \
--trainpref "$TASK_DATA_FOLDER/processed/train.label" \
--validpref "${DEVPREF//LANG/'label'}" \
--destdir "$TASK-bin/label" \
--workers 60;
else
# For STS-B, the output range is scaled to [0.0, 1.0].
mkdir "$TASK-bin/label"
awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
fi
done
@@ -24,11 +24,16 @@ from fairseq.models import FairseqEncoder, FairseqDecoder
 def save_checkpoint(args, trainer, epoch_itr, val_loss):
     from fairseq import distributed_utils, meters

+    prev_best = getattr(save_checkpoint, 'best', val_loss)
+    if val_loss is not None:
+        best_function = max if args.maximize_best_checkpoint_metric else min
+        save_checkpoint.best = best_function(val_loss, prev_best)
+
     if args.no_save or not distributed_utils.is_master(args):
         return

     def is_better(a, b):
-        return a > b if args.maximize_best_checkpoint_metric else a < b
+        return a >= b if args.maximize_best_checkpoint_metric else a <= b

     write_timer = meters.StopwatchMeter()
     write_timer.start()
@@ -52,9 +57,6 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     )
     checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints

-    prev_best = getattr(save_checkpoint, 'best', val_loss)
-    if val_loss is not None:
-        save_checkpoint.best = val_loss if is_better(val_loss, prev_best) else prev_best
-
     extra_state = {
         'train_iterator': epoch_itr.state_dict(),
         'val_loss': val_loss,
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('sentence_prediction')
class SentencePredictionCriterion(FairseqCriterion):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
features, extra = model(**sample['net_input'], features_only=True)
padding_mask = sample['net_input']['src_tokens'].eq(self.padding_idx)
assert hasattr(model, 'classification_heads') and \
'sentence_classification_head' in model.classification_heads, \
"model must provide sentence classification head for --criterion=sentence_prediction"
logits = model.classification_heads['sentence_classification_head'](
features,
padding_mask=padding_mask,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.args.regression_target:
loss = F.nll_loss(
F.log_softmax(logits, dim=-1, dtype=torch.float32),
targets,
reduction='sum',
)
else:
logits = logits.squeeze().float()
targets = targets.float()
loss = F.mse_loss(
logits,
targets,
reduction='sum',
)
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample_size,
'sample_size': sample_size,
}
if not self.args.regression_target:
preds = logits.max(dim=1)[1]
logging_output.update(
ncorrect=(preds == targets).sum().item()
)
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]:
ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
agg_output.update(accuracy=ncorrect/nsentences)
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
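For illustration, a small sketch of how these per-batch logging outputs are aggregated into the `accuracy` metric used with `--best-checkpoint-metric accuracy`; the numbers are made up, and the import path assumes the criterion lives at `fairseq/criterions/sentence_prediction.py`:

```python
# Two fake per-batch logging outputs, shaped like the ones produced in forward().
from fairseq.criterions.sentence_prediction import SentencePredictionCriterion

logs = [
    {'loss': 1.2, 'ntokens': 640, 'nsentences': 16, 'sample_size': 16, 'ncorrect': 12},
    {'loss': 0.9, 'ntokens': 655, 'nsentences': 16, 'sample_size': 16, 'ncorrect': 14},
]
agg = SentencePredictionCriterion.aggregate_logging_outputs(logs)
print(agg['accuracy'])  # 0.8125 == (12 + 14) / 32
```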
@@ -14,6 +14,7 @@ from .base_wrapper_dataset import BaseWrapperDataset
from .audio.raw_audio_dataset import RawAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
@@ -25,13 +26,17 @@ from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .truncate_dataset import TruncateDataset
from .iterators import (
CountingIterator,
@@ -44,6 +49,7 @@ __all__ = [
'BacktranslationDataset',
'BaseWrapperDataset',
'ConcatDataset',
'ConcatSentencesDataset',
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
@@ -64,15 +70,19 @@ __all__ = [
'NoisingDataset',
'NumelDataset',
'NumSamplesDataset',
"OffsetTokensDataset",
'PadDataset',
'PrependTokenDataset',
'RawAudioDataset',
"RawLabelDataset",
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedIterator',
'SortDataset',
"StripTokenDataset",
'TokenBlockDataset',
'TransformEosDataset',
'TransformEosLangPairDataset',
"TruncateDataset",
'TruncatedDictionary',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from . import FairseqDataset
class ConcatSentencesDataset(FairseqDataset):
def __init__(self, *datasets):
super().__init__()
self.datasets = datasets
assert all(len(ds) == len(datasets[0]) for ds in datasets), \
'datasets must have the same length'
def __getitem__(self, index):
return torch.cat([ds[index] for ds in self.datasets])
def __len__(self):
return len(self.datasets[0])
def collater(self, samples):
return self.datasets[0].collater(samples)
@property
def sizes(self):
return sum(ds.sizes for ds in self.datasets)
def num_tokens(self, index):
return sum(ds.num_tokens(index) for ds in self.datasets)
def size(self, index):
return sum(ds.size(index) for ds in self.datasets)
def ordered_indices(self):
return self.datasets[0].ordered_indices()
@property
def supports_prefetch(self):
return any(
getattr(ds, 'supports_prefetch', False) for ds in self.datasets
)
def prefetch(self, indices):
for ds in self.datasets:
if getattr(ds, 'supports_prefetch', False):
ds.prefetch(indices)
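For illustration, a toy sketch of how the `sentence_prediction` task (added below) uses this dataset to build sentence pairs: `input1` has the separator prepended and is then concatenated with `input0` per index. The token ids are made up; `0` and `2` stand in for `--init-token 0 --separator-token 2` from the README:

```python
import torch
from fairseq.data import ConcatSentencesDataset

# Stand-ins for the binarized input0/input1 splits; a real setup would load them
# with data_utils.load_indexed_dataset.
input0 = [torch.LongTensor([0, 31, 47, 2])]   # <s> sentence A </s>
input1 = [torch.LongTensor([2, 55, 61, 2])]   # </s> sentence B </s> (separator prepended)

src_tokens = ConcatSentencesDataset(input0, input1)
print(src_tokens[0])  # tensor([ 0, 31, 47,  2,  2, 55, 61,  2])
```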
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import BaseWrapperDataset
class OffsetTokensDataset(BaseWrapperDataset):
def __init__(self, dataset, offset):
super().__init__(dataset)
self.offset = offset
def __getitem__(self, idx):
return self.dataset[idx] + self.offset
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from . import FairseqDataset
class RawLabelDataset(FairseqDataset):
def __init__(self, labels):
super().__init__()
self.labels = labels
def __getitem__(self, index):
return self.labels[index]
def __len__(self):
return len(self.labels)
def collater(self, samples):
return torch.tensor(samples)
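For illustration, a toy sketch of how regression targets (e.g. STS-B scores already scaled to `[0.0, 1.0]` by the preprocessing script) are wrapped and collated; the values are made up:

```python
from fairseq.data import RawLabelDataset

targets = RawLabelDataset([0.76, 0.20, 1.00])       # per-example regression targets
batch = targets.collater([targets[0], targets[2]])  # collate a sampled mini-batch
print(batch)  # tensor([0.7600, 1.0000])
```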
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import BaseWrapperDataset
class StripTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, id_to_strip):
super().__init__(dataset)
self.id_to_strip = id_to_strip
def __getitem__(self, index):
item = self.dataset[index]
return item[item.ne(self.id_to_strip)]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
from . import BaseWrapperDataset
class TruncateDataset(BaseWrapperDataset):
def __init__(self, dataset, truncation_length):
super().__init__(dataset)
self.truncation_length = truncation_length
self.dataset = dataset
def __getitem__(self, index):
item = self.dataset[index]
item_len = item.size(0)
if item_len > self.truncation_length:
item = item[:self.truncation_length]
return item
@property
def sizes(self):
return np.minimum(self.dataset.sizes, self.truncation_length)
def __len__(self):
return len(self.dataset)
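For illustration, a toy sketch of the label pipeline used by the `sentence_prediction` task below: `StripTokenDataset` drops the trailing EOS appended during binarization and `OffsetTokensDataset` shifts label dictionary ids down to 0-based class indices. It assumes the default fairseq `Dictionary` layout (4 special symbols, EOS id 2), so the first label symbol has id 4:

```python
import torch
from fairseq.data import OffsetTokensDataset, StripTokenDataset

# Toy binarized label dataset: each item is [label_id, eos].
labels = [torch.LongTensor([4, 2]), torch.LongTensor([5, 2])]

stripped = StripTokenDataset(labels, id_to_strip=2)  # drop the trailing </s>
targets = OffsetTokensDataset(stripped, offset=-4)   # map dict ids -> class ids 0, 1, ...

print(targets[0], targets[1])  # tensor([0]) tensor([1])
```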
@@ -134,6 +134,15 @@ class RobertaModel(FairseqLanguageModel):
].size(0)
self.register_classification_head(head_name, num_classes, inner_dim)
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, 'classification_heads'):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + 'classification_heads.' + k not in state_dict:
print('Overwriting', prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
class RobertaLMHead(nn.Module):
"""Head for masked language modeling."""
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import numpy as np
from fairseq.data import (
ConcatSentencesDataset,
data_utils,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
OffsetTokensDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
StripTokenDataset,
TruncateDataset,
)
from . import FairseqTask, register_task
@register_task('sentence_prediction')
class SentencePredictionTask(FairseqTask):
"""
Sentence (or sentence pair) prediction (classification or regression) task.
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='FILE',
help='file prefix for data')
parser.add_argument('--max-positions', type=int, default=512,
help='max input length')
parser.add_argument('--num-classes', type=int, default=-1,
help='number of classes')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
parser.add_argument('--separator-token', type=int, default=None,
help='add separator token between inputs')
parser.add_argument('--regression-target', action='store_true', default=False)
parser.add_argument('--no-shuffle', action='store_true', default=False)
parser.add_argument('--truncate-sequence', action='store_true', default=False,
help='truncate sequence to max-positions')
def __init__(self, args, data_dictionary, label_dictionary):
super().__init__(args)
self.dictionary = data_dictionary
self.label_dictionary = label_dictionary
@classmethod
def load_dictionary(cls, args, filename, source=True):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.num_classes > 0, 'Must set --num-classes'
args.tokens_per_sample = args.max_positions
# load data dictionary
data_dict = cls.load_dictionary(
args,
os.path.join(args.data, 'input0', 'dict.txt'),
source=True,
)
print('| [input] dictionary: {} types'.format(len(data_dict)))
label_dict = None
if not args.regression_target:
# load label dictionary
label_dict = cls.load_dictionary(
args,
os.path.join(args.data, 'label', 'dict.txt'),
source=False,
)
print('| [label] dictionary: {} types'.format(len(label_dict)))
else:
label_dict = data_dict
return SentencePredictionTask(args, data_dict, label_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
def get_path(type, split):
return os.path.join(self.args.data, type, split)
def make_dataset(type, dictionary):
split_path = get_path(type, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.args.dataset_impl,
combine=combine,
)
return dataset
input0 = make_dataset('input0', self.source_dictionary)
assert input0 is not None, 'could not find dataset: {}'.format(get_path('input0', split))
input1 = make_dataset('input1', self.source_dictionary)
if self.args.init_token is not None:
input0 = PrependTokenDataset(input0, self.args.init_token)
if input1 is None:
src_tokens = input0
else:
if self.args.separator_token is not None:
input1 = PrependTokenDataset(input1, self.args.separator_token)
src_tokens = ConcatSentencesDataset(input0, input1)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(src_tokens))
if self.args.truncate_sequence:
src_tokens = TruncateDataset(src_tokens, self.args.max_positions)
dataset = {
'id': IdDataset(),
'net_input': {
'src_tokens': RightPadDataset(
src_tokens,
pad_idx=self.source_dictionary.pad(),
),
'src_lengths': NumelDataset(src_tokens, reduce=False),
},
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(src_tokens, reduce=True),
}
if not self.args.regression_target:
label_dataset = make_dataset('label', self.target_dictionary)
if label_dataset is not None:
dataset.update(
target=OffsetTokensDataset(
StripTokenDataset(
label_dataset,
id_to_strip=self.target_dictionary.eos(),
),
offset=-self.target_dictionary.nspecial,
)
)
else:
label_path = f"{get_path('label', split)}.label"
if os.path.exists(label_path):
dataset.update(
target=RawLabelDataset([
float(x.strip()) for x in open(label_path).readlines()
])
)
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[src_tokens.sizes],
)
if self.args.no_shuffle:
dataset = nested_dataset
else:
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
print(f"| Loaded {split} with #samples: {len(dataset)}")
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
model.register_classification_head(
'sentence_classification_head',
num_classes=self.args.num_classes,
)
return model
def max_positions(self):
return self.args.max_positions
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.label_dictionary
@@ -130,7 +130,7 @@ def train(args, trainer, task, epoch_itr):
         for k, v in log_output.items():
             if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                 continue  # these are already logged above
-            if 'loss' in k:
+            if 'loss' in k or k == 'accuracy':
                 extra_meters[k].update(v, log_output['sample_size'])
             else:
                 extra_meters[k].update(v)
@@ -236,16 +236,20 @@ def validate(args, trainer, task, epoch_itr, subsets):
                 extra_meters[k].update(v)

         # log validation stats
-        stats = get_valid_stats(trainer)
+        stats = get_valid_stats(trainer, args, extra_meters)
         for k, meter in extra_meters.items():
             stats[k] = meter.avg
         progress.print(stats, tag=subset, step=trainer.get_num_updates())

-        valid_losses.append(stats[args.best_checkpoint_metric].avg)
+        valid_losses.append(
+            stats[args.best_checkpoint_metric].avg
+            if args.best_checkpoint_metric == 'loss'
+            else stats[args.best_checkpoint_metric]
+        )
     return valid_losses


-def get_valid_stats(trainer):
+def get_valid_stats(trainer, args, extra_meters=None):
     stats = collections.OrderedDict()
     stats['loss'] = trainer.get_meter('valid_loss')
     if trainer.get_meter('valid_nll_loss').count > 0:
@@ -256,8 +260,23 @@ def get_valid_stats(trainer):
         stats['ppl'] = utils.get_perplexity(nll_loss.avg)
     stats['num_updates'] = trainer.get_num_updates()
     if hasattr(checkpoint_utils.save_checkpoint, 'best'):
-        stats['best_loss'] = min(
-            checkpoint_utils.save_checkpoint.best, stats['loss'].avg)
+        key = f'best_{args.best_checkpoint_metric}'
+        best_function = max if args.maximize_best_checkpoint_metric else min
+
+        current_metric = None
+        if args.best_checkpoint_metric == 'loss':
+            current_metric = stats['loss'].avg
+        elif args.best_checkpoint_metric in extra_meters:
+            current_metric = extra_meters[args.best_checkpoint_metric].avg
+        elif args.best_checkpoint_metric in stats:
+            current_metric = stats[args.best_checkpoint_metric]
+        else:
+            raise ValueError("best_checkpoint_metric not found in logs")
+
+        stats[key] = best_function(
+            checkpoint_utils.save_checkpoint.best,
+            current_metric,
+        )
     return stats
......