#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# raw glue data as downloaded by glue download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
if [[ $# -ne 2 ]]; then
echo "Run as following:"
echo "./examples/roberta/preprocess_GLUE_tasks.sh <glud_data_folder> <task_name>"
exit 1
fi
GLUE_DATA_FOLDER=$1
# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
TASKS=$2 # QQP
if [ "$TASKS" = "ALL" ]
then
TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi
for TASK in $TASKS
do
echo "Preprocessing $TASK"
TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER"
SPLITS="train dev test"
INPUT_COUNT=2
if [ "$TASK" = "QQP" ]
then
INPUT_COLUMNS=( 4 5 )
TEST_INPUT_COLUMNS=( 2 3 )
LABEL_COLUMN=6
elif [ "$TASK" = "MNLI" ]
then
SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
INPUT_COLUMNS=( 9 10 )
TEST_INPUT_COLUMNS=( 9 10 )
DEV_LABEL_COLUMN=16
LABEL_COLUMN=12
elif [ "$TASK" = "QNLI" ]
then
INPUT_COLUMNS=( 2 3 )
TEST_INPUT_COLUMNS=( 2 3 )
LABEL_COLUMN=4
elif [ "$TASK" = "MRPC" ]
then
INPUT_COLUMNS=( 4 5 )
TEST_INPUT_COLUMNS=( 4 5 )
LABEL_COLUMN=1
elif [ "$TASK" = "RTE" ]
then
INPUT_COLUMNS=( 2 3 )
TEST_INPUT_COLUMNS=( 2 3 )
LABEL_COLUMN=4
elif [ "$TASK" = "STS-B" ]
then
INPUT_COLUMNS=( 8 9 )
TEST_INPUT_COLUMNS=( 8 9 )
LABEL_COLUMN=10
# Following are single sentence tasks.
elif [ "$TASK" = "SST-2" ]
then
INPUT_COLUMNS=( 1 )
TEST_INPUT_COLUMNS=( 2 )
LABEL_COLUMN=2
INPUT_COUNT=1
elif [ "$TASK" = "CoLA" ]
then
INPUT_COLUMNS=( 4 )
TEST_INPUT_COLUMNS=( 2 )
LABEL_COLUMN=2
INPUT_COUNT=1
fi
# Strip out header and filter lines that don't have expected number of fields.
rm -rf "$TASK_DATA_FOLDER/processed"
mkdir -p "$TASK_DATA_FOLDER/processed"
for SPLIT in $SPLITS
do
# CoLA train and dev doesn't have header.
if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
then
cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
else
tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
fi
# Remove unformatted lines from train and dev files for QQP dataset.
if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
then
awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
else
cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
fi
rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
done
# Split into input0, input1 and label
for SPLIT in $SPLITS
do
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
do
if [[ "$SPLIT" != test* ]]
then
COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
else
COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
fi
cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
done
if [[ "$SPLIT" != test* ]]
then
if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
then
cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
else
cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
fi
fi
# BPE encode.
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
do
LANG="input$INPUT_TYPE"
echo "BPE encoding $SPLIT/$LANG"
python -m examples.roberta.multiprocessing_bpe_encoder \
--encoder-json encoder.json \
--vocab-bpe vocab.bpe \
--inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
--outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
--workers 60 \
--keep-empty;
done
done
# Remove output directory.
rm -rf "$TASK-bin"
DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
if [ "$TASK" = "MNLI" ]
then
DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
fi
# Run fairseq preprocessing:
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
do
LANG="input$INPUT_TYPE"
fairseq-preprocess \
--only-source \
--trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
--validpref "${DEVPREF//LANG/$LANG}" \
--testpref "${TESTPREF//LANG/$LANG}" \
--destdir "$TASK-bin/$LANG" \
--workers 60 \
--srcdict dict.txt;
done
if [[ "$TASK" != "STS-B" ]]
then
fairseq-preprocess \
--only-source \
--trainpref "$TASK_DATA_FOLDER/processed/train.label" \
--validpref "${DEVPREF//LANG/label}" \
--destdir "$TASK-bin/label" \
--workers 60;
else
# For STS-B output range is converted to be between: [0.0, 1.0]
mkdir -p "$TASK-bin/label"
awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
fi
done
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os
import re
class InputExample:
def __init__(self, paragraph, qa_list, label):
self.paragraph = paragraph
self.qa_list = qa_list
self.label = label
def get_examples(data_dir, set_type):
"""
Extract paragraph and question-answer list from each json file
"""
examples = []
levels = ["middle", "high"]
set_type_c = set_type.split("-")
if len(set_type_c) == 2:
levels = [set_type_c[1]]
set_type = set_type_c[0]
for level in levels:
cur_dir = os.path.join(data_dir, set_type, level)
for filename in os.listdir(cur_dir):
cur_path = os.path.join(cur_dir, filename)
with open(cur_path, "r") as f:
cur_data = json.load(f)
answers = cur_data["answers"]
options = cur_data["options"]
questions = cur_data["questions"]
context = cur_data["article"].replace("\n", " ")
context = re.sub(r"\s+", " ", context)
for i in range(len(answers)):
label = ord(answers[i]) - ord("A")
qa_list = []
question = questions[i]
for j in range(4):
option = options[i][j]
if "_" in question:
qa_cat = question.replace("_", option)
else:
qa_cat = " ".join([question, option])
qa_cat = re.sub(r"\s+", " ", qa_cat)
qa_list.append(qa_cat)
examples.append(InputExample(context, qa_list, label))
return examples
def main():
"""
Helper script to extract paragraphs, questions, and answers from the RACE dataset.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-dir",
help="input directory for downloaded RACE dataset",
)
parser.add_argument(
"--output-dir",
help="output directory for extracted data",
)
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
for set_type in ["train", "dev", "test-middle", "test-high"]:
examples = get_examples(args.input_dir, set_type)
qa_file_paths = [
os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
for i in range(4)
]
qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
outf_label_path = os.path.join(args.output_dir, set_type + ".label")
outf_context = open(outf_context_path, "w")
outf_label = open(outf_label_path, "w")
for example in examples:
outf_context.write(example.paragraph + "\n")
for i in range(4):
qa_files[i].write(example.qa_list[i] + "\n")
outf_label.write(str(example.label) + "\n")
for f in qa_files:
f.close()
outf_label.close()
outf_context.close()
if __name__ == "__main__":
main()
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# data should be downloaded and processed with preprocess_RACE.py
if [[ $# -ne 2 ]]; then
echo "Run as following:"
echo "./examples/roberta/preprocess_RACE.sh <race_data_folder> <output_folder>"
exit 1
fi
RACE_DATA_FOLDER=$1
OUT_DATA_FOLDER=$2
# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
SPLITS="train dev test-middle test-high"
INPUT_TYPES="input0 input1 input2 input3 input4"
for INPUT_TYPE in $INPUT_TYPES
do
for SPLIT in $SPLITS
do
echo "BPE encoding $SPLIT/$INPUT_TYPE"
python -m examples.roberta.multiprocessing_bpe_encoder \
--encoder-json encoder.json \
--vocab-bpe vocab.bpe \
--inputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE" \
--outputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE.bpe" \
--workers 10 \
--keep-empty;
done
done
for INPUT_TYPE in $INPUT_TYPES
do
LANG="input$INPUT_TYPE"
fairseq-preprocess \
--only-source \
--trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
--validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
--testpref "$RACE_DATA_FOLDER/test-middle.$INPUT_TYPE.bpe,$RACE_DATA_FOLDER/test-high.$INPUT_TYPE.bpe" \
--destdir "$OUT_DATA_FOLDER/$INPUT_TYPE" \
--workers 10 \
--srcdict dict.txt;
done
rm -rf "$OUT_DATA_FOLDER/label"
mkdir -p "$OUT_DATA_FOLDER/label"
cp "$RACE_DATA_FOLDER/train.label" "$OUT_DATA_FOLDER/label/"
cp "$RACE_DATA_FOLDER/dev.label" "$OUT_DATA_FOLDER/label/valid.label"
cp "$RACE_DATA_FOLDER/test-middle.label" "$OUT_DATA_FOLDER/label/test.label"
cp "$RACE_DATA_FOLDER/test-high.label" "$OUT_DATA_FOLDER/label/test1.label"
# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
The following instructions can be used to finetune RoBERTa on the WSC training
data provided by [SuperGLUE](https://super.gluebenchmark.com/).
Note that there is high variance in the results. For our GLUE/SuperGLUE
submission we swept over the learning rate (1e-5, 2e-5, 3e-5), batch size (16,
32, 64) and total number of updates (500, 1000, 2000, 3000), as well as the
random seed. Out of ~100 runs we chose the best 7 models and ensembled them.
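For concreteness, that grid can be written down as follows. This is only a sketch of the sweep described above; the sweep launcher itself is not part of this commit, and the flag names follow the training command in step 2 below:

```python
# Hypothetical sketch of the hyperparameter sweep described above.
from itertools import product

for lr, bsz, updates in product(
    [1e-5, 2e-5, 3e-5],        # learning rate
    [16, 32, 64],              # batch size
    [500, 1000, 2000, 3000],   # total number of updates
):
    print(
        f"fairseq-train WSC/ --lr {lr} --batch-size {bsz} "
        f"--max-update {updates} --total-num-update {updates} ..."
    )
```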
**Approach:** The instructions below use a slightly different loss function than
what's described in the original RoBERTa arXiv paper. In particular,
[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
ranking loss between `(query, candidate)` pairs with tunable hyperparameters
alpha and beta. This is supported in our code as well via the `--wsc-margin-alpha`
and `--wsc-margin-beta` arguments. However, we achieved slightly better (and more robust)
results on the development set by instead using a single cross entropy loss term
over the log-probabilities for the query and all mined candidates. **The
candidates are mined using spaCy.** This reduces the number of
hyperparameters and our best model achieved 92.3% development set accuracy,
compared to ~90% accuracy for the margin loss. Later versions of the RoBERTa
arXiv paper will describe this updated formulation.
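For reference, here is a minimal sketch of the two loss variants; it mirrors the `get_loss` method of the WSC criterion included later in this commit, where `query_lprobs` and `cand_lprobs` are the averaged log-probabilities of the query span and of the mined candidate spans:

```python
import torch
import torch.nn.functional as F

def margin_loss(query_lprobs, cand_lprobs, alpha=1.0, beta=0.0):
    # Kocijan et al. (2019): hinge on each candidate that outscores the query.
    return (
        -query_lprobs
        + alpha * (cand_lprobs - query_lprobs + beta).clamp(min=0)
    ).sum()

def cross_entropy_loss(query_lprobs, cand_lprobs):
    # single cross entropy term treating the query as the correct "class"
    scores = torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0)
    return F.cross_entropy(scores, scores.new_zeros(1).long())
```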
### 1) Download the WSC data from the SuperGLUE website:
```bash
wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
unzip WSC.zip
# we also need to copy the RoBERTa dictionary into the same directory
wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
```
### 2) Finetune over the provided training data:
```bash
TOTAL_NUM_UPDATES=2000 # Total number of training steps.
WARMUP_UPDATES=250 # Linearly increase LR over this many steps.
LR=2e-05 # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=16 # Batch size per GPU.
SEED=1 # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt
# we use the --user-dir option to load the task and criterion
# from the examples/roberta/wsc directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend legacy_ddp \
--user-dir $FAIRSEQ_USER_DIR \
--task wsc --criterion wsc --wsc-cross-entropy \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--batch-size $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100 \
--seed $SEED
```
The above command assumes training on 4 GPUs, but you can achieve the same
results on a single GPU by adding `--update-freq=4`.
### 3) Evaluate
```python
from fairseq.models.roberta import RobertaModel
from examples.roberta.wsc import wsc_utils # also loads WSC task and criterion
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
roberta.cuda()
nsamples, ncorrect = 0, 0
for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
pred = roberta.disambiguate_pronoun(sentence)
nsamples += 1
if pred == label:
ncorrect += 1
print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.9230769230769231
```
## RoBERTa training on WinoGrande dataset
We also provide a `winogrande` task and criterion for finetuning on
[WinoGrande](https://mosaic.allenai.org/projects/winogrande)-like datasets,
where there are always exactly two candidates and one is correct. This is a
more efficient implementation for that special case.
```bash
TOTAL_NUM_UPDATES=23750 # Total number of training steps.
WARMUP_UPDATES=2375 # Linearly increase LR over this many steps.
LR=1e-05 # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=32 # Batch size per GPU.
SEED=1 # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt
# we use the --user-dir option to load the task and criterion
# from the examples/roberta/wsc directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
cd fairseq
CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend legacy_ddp \
--user-dir $FAIRSEQ_USER_DIR \
--task winogrande --criterion winogrande \
--wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--batch-size $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import wsc_criterion # noqa
from . import wsc_task # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders
@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args.bpe)
self.tokenizer = encoders.build_tokenizer(args.tokenizer)
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
parser.add_argument(
"--wsc-cross-entropy",
action="store_true",
help="use cross entropy formulation instead of margin loss",
)
parser.add_argument(
"--save-predictions", metavar="FILE", help="file to save predictions to"
)
def get_masked_input(self, tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.task.mask
return masked_tokens
def get_lprobs(self, model, tokens, mask):
logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
def get_loss(self, query_lprobs, cand_lprobs):
if self.args.wsc_cross_entropy:
return F.cross_entropy(
torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
query_lprobs.new([0]).long(),
)
else:
return (
-query_lprobs
+ self.args.wsc_margin_alpha
* (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
).sum()
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
loss, nloss = 0.0, 0
ncorrect, nqueries = 0, 0
for i, label in enumerate(sample["labels"]):
query_lprobs = self.get_lprobs(
model,
sample["query_tokens"][i].unsqueeze(0),
sample["query_masks"][i].unsqueeze(0),
)
cand_lprobs = self.get_lprobs(
model,
sample["candidate_tokens"][i],
sample["candidate_masks"][i],
)
pred = (query_lprobs >= cand_lprobs).all().item()
if label is not None:
label = 1 if label else 0
ncorrect += 1 if pred == label else 0
nqueries += 1
if label:
# only compute a loss for positive instances
nloss += 1
loss += self.get_loss(query_lprobs, cand_lprobs)
id = sample["id"][i].item()
if self.prediction_h is not None:
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
if nloss == 0:
loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ncorrect": ncorrect,
"nqueries": nqueries,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / sample_size / math.log(2),
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
if nqueries > 0:
agg_output["accuracy"] = ncorrect / float(nqueries)
return agg_output
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
query_lprobs = self.get_lprobs(
model,
sample["query_tokens"],
sample["query_masks"],
)
cand_lprobs = self.get_lprobs(
model,
sample["candidate_tokens"],
sample["candidate_masks"],
)
pred = query_lprobs >= cand_lprobs
loss = self.get_loss(query_lprobs, cand_lprobs)
sample_size = sample["query_tokens"].size(0)
ncorrect = pred.sum().item()
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ncorrect": ncorrect,
"nqueries": sample_size,
}
return loss, sample_size, logging_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
Dictionary,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PadDataset,
SortDataset,
data_utils,
encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from . import wsc_utils
@register_task("wsc")
class WSCTask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
)
parser.add_argument(
"--init-token",
type=int,
default=None,
help="add token at the beginning of each batch item",
)
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
self.mask = vocab.add_symbol("<mask>")
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
# hack to handle GPT-2 BPE, which includes leading spaces
if args.bpe == "gpt2":
self.leading_space = True
self.trailing_space = False
else:
self.leading_space = False
self.trailing_space = True
@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol("<mask>")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == "wsc", "Must set --criterion=wsc"
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
def binarize(self, s: str, append_eos: bool = False):
if self.tokenizer is not None:
s = self.tokenizer.encode(s)
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s,
append_eos=append_eos,
add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens
def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space):
toks = self.binarize(
prefix + leading_space + txt + trailing_space + suffix,
append_eos=True,
)
mask = torch.zeros_like(toks, dtype=torch.bool)
mask_start = len(self.binarize(prefix))
mask_size = len(self.binarize(leading_space + txt))
mask[mask_start : mask_start + mask_size] = 1
return toks, mask
def load_dataset(
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
raise FileNotFoundError("Cannot find data: {}".format(data_path))
query_tokens = []
query_masks = []
query_lengths = []
candidate_tokens = []
candidate_masks = []
candidate_lengths = []
labels = []
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
prefix = sentence[: pronoun_span.start].text
suffix = sentence[pronoun_span.end :].text_with_ws
# spaCy spans include trailing spaces, but we need to know about
# leading spaces for the GPT-2 BPE
leading_space = (
" " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
)
trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""
# get noun phrases, excluding pronouns and anything overlapping with the query
cand_spans = wsc_utils.filter_noun_chunks(
wsc_utils.extended_noun_chunks(sentence),
exclude_pronouns=True,
exclude_query=query,
exact_match=False,
)
if query is not None:
query_toks, query_mask = self.binarize_with_mask(
query, prefix, suffix, leading_space, trailing_space
)
query_len = len(query_toks)
else:
query_toks, query_mask, query_len = None, None, 0
query_tokens.append(query_toks)
query_masks.append(query_mask)
query_lengths.append(query_len)
cand_toks, cand_masks = [], []
for cand_span in cand_spans:
toks, mask = self.binarize_with_mask(
cand_span.text,
prefix,
suffix,
leading_space,
trailing_space,
)
cand_toks.append(toks)
cand_masks.append(mask)
# collate candidates
cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
assert cand_toks.size() == cand_masks.size()
candidate_tokens.append(cand_toks)
candidate_masks.append(cand_masks)
candidate_lengths.append(cand_toks.size(1))
labels.append(label)
query_lengths = np.array(query_lengths)
query_tokens = ListDataset(query_tokens, query_lengths)
query_masks = ListDataset(query_masks, query_lengths)
candidate_lengths = np.array(candidate_lengths)
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
labels = ListDataset(labels, [1] * len(labels))
dataset = {
"id": IdDataset(),
"query_tokens": query_tokens,
"query_masks": query_masks,
"candidate_tokens": candidate_tokens,
"candidate_masks": candidate_masks,
"labels": labels,
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[query_lengths],
)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(query_tokens))
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
if return_only:
return dataset
self.datasets[split] = dataset
return self.datasets[split]
def build_dataset_for_inference(self, sample_json):
with tempfile.NamedTemporaryFile(buffering=0) as h:
h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
dataset = self.load_dataset(
"disambiguate_pronoun",
data_path=h.name,
return_only=True,
)
return dataset
def disambiguate_pronoun(self, model, sentence, use_cuda=False):
sample_json = wsc_utils.convert_sentence_to_json(sentence)
dataset = self.build_dataset_for_inference(sample_json)
sample = dataset.collater([dataset[0]])
if use_cuda:
sample = utils.move_to_cuda(sample)
def get_masked_input(tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask.bool()] = self.mask
return masked_tokens
def get_lprobs(tokens, mask):
logits, _ = model(src_tokens=get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
cand_lprobs = get_lprobs(
sample["candidate_tokens"][0],
sample["candidate_masks"][0],
)
if sample["query_tokens"][0] is not None:
query_lprobs = get_lprobs(
sample["query_tokens"][0].unsqueeze(0),
sample["query_masks"][0].unsqueeze(0),
)
return (query_lprobs >= cand_lprobs).all().item() == 1
else:
best_idx = cand_lprobs.argmax().item()
full_cand = sample["candidate_tokens"][0][best_idx]
mask = sample["candidate_masks"][0][best_idx]
toks = full_cand[mask.bool()]
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@property
def source_dictionary(self):
return self.vocab
@property
def target_dictionary(self):
return self.vocab
@register_task("winogrande")
class WinograndeTask(WSCTask):
"""
Task for WinoGrande dataset. Efficient implementation for Winograd schema
tasks with exactly two candidates, one of which is correct.
"""
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == "winogrande", "Must set --criterion=winogrande"
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
def load_dataset(
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
raise FileNotFoundError("Cannot find data: {}".format(data_path))
query_tokens = []
query_masks = []
query_lengths = []
candidate_tokens = []
candidate_masks = []
candidate_lengths = []
itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))
for sample in itr:
sentence, pronoun_span, query, cand_text = sample
prefix = sentence[: pronoun_span[0]].rstrip()
suffix = sentence[pronoun_span[1] :]
leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
trailing_space = ""
if query is not None:
query_toks, query_mask = self.binarize_with_mask(
query,
prefix,
suffix,
leading_space,
trailing_space,
)
query_len = len(query_toks)
else:
query_toks, query_mask, query_len = None, None, 0
query_tokens.append(query_toks)
query_masks.append(query_mask)
query_lengths.append(query_len)
cand_toks, cand_mask = self.binarize_with_mask(
cand_text,
prefix,
suffix,
leading_space,
trailing_space,
)
candidate_tokens.append(cand_toks)
candidate_masks.append(cand_mask)
candidate_lengths.append(cand_toks.size(0))
query_lengths = np.array(query_lengths)
def get_pad_dataset_fn(tokens, length, pad_idx):
return PadDataset(
ListDataset(tokens, length),
pad_idx=pad_idx,
left_pad=False,
)
query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)
candidate_lengths = np.array(candidate_lengths)
candidate_tokens = get_pad_dataset_fn(
candidate_tokens, candidate_lengths, self.vocab.pad()
)
candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)
dataset = {
"id": IdDataset(),
"query_tokens": query_tokens,
"query_masks": query_masks,
"candidate_tokens": candidate_tokens,
"candidate_masks": candidate_masks,
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[query_lengths],
)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(query_tokens))
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
if return_only:
return dataset
self.datasets[split] = dataset
return self.datasets[split]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
from functools import lru_cache
def convert_sentence_to_json(sentence):
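    # Parses the WSC markup produced by `jsonl_iterator(..., eval=True)`:
    # the query span is wrapped in underscores ("_the city_") and the pronoun
    # in square brackets ("[it]"); the *_index fields are word offsets into
    # the text with the markup removed.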
if "_" in sentence:
prefix, rest = sentence.split("_", 1)
query, rest = rest.split("_", 1)
query_index = len(prefix.rstrip().split(" "))
else:
query, query_index = None, None
prefix, rest = sentence.split("[", 1)
pronoun, rest = rest.split("]", 1)
pronoun_index = len(prefix.rstrip().split(" "))
sentence = sentence.replace("_", "").replace("[", "").replace("]", "")
return {
"idx": 0,
"text": sentence,
"target": {
"span1_index": query_index,
"span1_text": query,
"span2_index": pronoun_index,
"span2_text": pronoun,
},
}
def extended_noun_chunks(sentence):
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
np_start, cur_np = 0, "NONE"
for i, token in enumerate(sentence):
np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
if np_type != cur_np:
if cur_np != "NONE":
noun_chunks.add((np_start, i))
if np_type != "NONE":
np_start = i
cur_np = np_type
if cur_np != "NONE":
noun_chunks.add((np_start, len(sentence)))
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
def find_token(sentence, start_pos):
found_tok = None
for tok in sentence:
if tok.idx == start_pos:
found_tok = tok
break
return found_tok
def find_span(sentence, search_text, start=0):
search_text = search_text.lower()
for tok in sentence[start:]:
remainder = sentence[tok.i :].text.lower()
if remainder.startswith(search_text):
len_to_consume = len(search_text)
start_idx = tok.idx
for next_tok in sentence[tok.i :]:
end_idx = next_tok.idx + len(next_tok.text)
if end_idx - start_idx == len_to_consume:
span = sentence[tok.i : next_tok.i + 1]
return span
return None
@lru_cache(maxsize=1)
def get_detokenizer():
from sacremoses import MosesDetokenizer
detok = MosesDetokenizer(lang="en")
return detok
@lru_cache(maxsize=1)
def get_spacy_nlp():
import en_core_web_lg
nlp = en_core_web_lg.load()
return nlp
def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
detok = get_detokenizer()
nlp = get_spacy_nlp()
with open(input_fname) as fin:
for line in fin:
sample = json.loads(line.strip())
if positive_only and "label" in sample and not sample["label"]:
# only consider examples where the query is correct
continue
target = sample["target"]
# clean up the query
query = target["span1_text"]
if query is not None:
if "\n" in query:
continue
if query.endswith(".") or query.endswith(","):
query = query[:-1]
# split tokens
tokens = sample["text"].split(" ")
def strip_pronoun(x):
return x.rstrip('.,"')
# find the pronoun
pronoun_idx = target["span2_index"]
pronoun = strip_pronoun(target["span2_text"])
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
# hack: sometimes the index is misaligned
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
pronoun_idx += 1
else:
raise Exception("Misaligned pronoun!")
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
# split tokens before and after the pronoun
before = tokens[:pronoun_idx]
after = tokens[pronoun_idx + 1 :]
# the GPT BPE attaches leading spaces to tokens, so we keep track
# of whether we need spaces before or after the pronoun
leading_space = " " if pronoun_idx > 0 else ""
trailing_space = " " if len(after) > 0 else ""
# detokenize
before = detok.detokenize(before, return_str=True)
pronoun = detok.detokenize([pronoun], return_str=True)
after = detok.detokenize(after, return_str=True)
# hack: when the pronoun ends in a period (or comma), move the
# punctuation to the "after" part
if pronoun.endswith(".") or pronoun.endswith(","):
after = pronoun[-1] + trailing_space + after
pronoun = pronoun[:-1]
# hack: when the "after" part begins with a comma or period, remove
# the trailing space
if after.startswith(".") or after.startswith(","):
trailing_space = ""
# parse sentence with spacy
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
# find pronoun span
start = len(before + leading_space)
first_pronoun_tok = find_token(sentence, start_pos=start)
pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
assert pronoun_span.text == pronoun
if eval:
# convert to format where pronoun is surrounded by "[]" and
# query is surrounded by "_"
query_span = find_span(sentence, query)
query_with_ws = "_{}_{}".format(
query_span.text,
(" " if query_span.text_with_ws.endswith(" ") else ""),
)
pronoun_with_ws = "[{}]{}".format(
pronoun_span.text,
(" " if pronoun_span.text_with_ws.endswith(" ") else ""),
)
if query_span.start < pronoun_span.start:
first = (query_span, query_with_ws)
second = (pronoun_span, pronoun_with_ws)
else:
first = (pronoun_span, pronoun_with_ws)
second = (query_span, query_with_ws)
sentence = (
sentence[: first[0].start].text_with_ws
+ first[1]
+ sentence[first[0].end : second[0].start].text_with_ws
+ second[1]
+ sentence[second[0].end :].text
)
yield sentence, sample.get("label", None)
else:
yield sentence, pronoun_span, query, sample.get("label", None)
def winogrande_jsonl_iterator(input_fname, eval=False):
with open(input_fname) as fin:
for line in fin:
sample = json.loads(line.strip())
sentence, option1, option2 = (
sample["sentence"],
sample["option1"],
sample["option2"],
)
pronoun_span = (sentence.index("_"), sentence.index("_") + 1)
if eval:
query, cand = option1, option2
else:
query = option1 if sample["answer"] == "1" else option2
cand = option2 if sample["answer"] == "1" else option1
yield sentence, pronoun_span, query, cand
def filter_noun_chunks(
chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
if exclude_pronouns:
chunks = [
np
for np in chunks
if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
]
if exclude_query is not None:
excl_txt = [exclude_query.lower()]
filtered_chunks = []
for chunk in chunks:
lower_chunk = chunk.text.lower()
found = False
for excl in excl_txt:
if (
not exact_match and (lower_chunk in excl or excl in lower_chunk)
) or lower_chunk == excl:
found = True
break
if not found:
filtered_chunks.append(chunk)
chunks = filtered_chunks
return chunks
[Better Fine-Tuning by Reducing Representational Collapse](https://arxiv.org/abs/2008.03156)
=====================
This repo contains the code to replicate all experiments from the _Better Fine-Tuning by Reducing Representational Collapse_ paper excluding the probing results.
The R3F sentence prediction criterion is registered as `sentence_prediction_r3f` while the label smoothing version of it is implemented as `label_smoothed_cross_entropy_r3f`. The R4F version of the sentence prediction criterion can be achieved by applying spectral norm to the classification head via the `--spectral-norm-classification-head` parameter.
## Hyper-parameters
Our methods introduce three new hyper-parameters: `--eps`, which sets the standard deviation or range of the distribution we sample from; `--r3f-lambda`, which controls how the logistic loss and the noisy KL loss are combined; and `--noise-type`, which selects the parametric distribution we use (`normal` or `uniform`).
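Schematically, the criterion adds a symmetric KL term between the model's outputs on clean and on noised token embeddings, mirroring the `sentence_prediction_r3f` implementation below. A minimal sketch, assuming a `model_logits` callable that maps embeddings to logits (a stand-in for the actual fairseq model call):

```python
import torch
import torch.nn.functional as F

def r3f_loss(model_logits, embeddings, targets,
             eps=1e-5, r3f_lambda=1.0, noise_type="uniform"):
    if noise_type == "uniform":
        sampler = torch.distributions.uniform.Uniform(low=-eps, high=eps)
    else:  # "normal"
        sampler = torch.distributions.normal.Normal(loc=0.0, scale=eps)
    logits = model_logits(embeddings)
    noised = model_logits(embeddings + sampler.sample(embeddings.shape).to(embeddings))
    task_loss = F.nll_loss(F.log_softmax(logits, dim=-1), targets, reduction="sum")
    symm_kl = (
        F.kl_div(F.log_softmax(noised, dim=-1), F.softmax(logits, dim=-1), reduction="sum")
        + F.kl_div(F.log_softmax(logits, dim=-1), F.softmax(noised, dim=-1), reduction="sum")
    )
    return task_loss + r3f_lambda * symm_kl
```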
For example, to run R3F on RTE from GLUE:
```bash
TOTAL_NUM_UPDATES=3120
WARMUP_UPDATES=187
LR=1e-05
NUM_CLASSES=2
MAX_SENTENCES=8 # Batch size.
ROBERTA_PATH=/path/to/roberta/model.pt
CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin \
--restore-file $ROBERTA_PATH \
--max-positions 512 \
--max-sentences $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 --separator-token 2 \
--arch roberta_large \
--criterion sentence_prediction_r3f \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--find-unused-parameters \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--noise-type uniform --r3f-lambda 0.7 \
--user-dir examples/rxf/rxf_src
```
## Citation
```bibtex
@article{aghajanyan2020better,
title={Better Fine-Tuning by Reducing Representational Collapse},
author={Aghajanyan, Armen and Shrivastava, Akshat and Gupta, Anchit and Goyal, Naman and Zettlemoyer, Luke and Gupta, Sonal},
journal={arXiv preprint arXiv:2008.03156},
year={2020}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import rxf_src # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import label_smoothed_cross_entropy_r3f, sentence_prediction_r3f # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
@register_criterion("label_smoothed_cross_entropy_r3f")
class LabelSmoothedCrossEntropyR3FCriterion(FairseqCriterion):
def __init__(
self, task, sentence_avg, label_smoothing, eps, r3f_lambda, noise_type
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.label_smoothing = label_smoothing
self.eps = eps
self.r3f_lambda = r3f_lambda
self.noise_type = noise_type
if self.noise_type in {"normal"}:
self.noise_sampler = torch.distributions.normal.Normal(
loc=0.0, scale=self.eps
)
elif self.noise_type == "uniform":
self.noise_sampler = torch.distributions.uniform.Uniform(
low=-self.eps, high=self.eps
)
else:
raise Exception(f"unrecognized noise type {self.noise_type}")
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
parser.add_argument('--eps', type=float, default=1e-5,
help='noise eps')
parser.add_argument('--r3f-lambda', type=float, default=1.0,
help='lambda for combining logistic loss and noisy KL loss')
parser.add_argument('--noise-type', type=str, default='normal',
choices=['normal', 'uniform'],
help='type of noises')
# fmt: on
def _get_symm_kl(self, noised_logits, input_logits):
return (
F.kl_div(
F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
F.softmax(input_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
+ F.kl_div(
F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
F.softmax(noised_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
) / noised_logits.size(0)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
token_embeddings = model.encoder.embed_tokens(sample["net_input"]["src_tokens"])
input_logits, extra = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(
model, (input_logits, extra), sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
if model.training:
noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
token_embeddings
)
noised_embeddings = token_embeddings.clone() + noise
noised_logits, _ = model(
**sample["net_input"], token_embeddings=noised_embeddings
)
symm_kl = self._get_symm_kl(noised_logits, input_logits)
if model.training:
symm_kl = symm_kl * sample_size
loss = loss + self.r3f_lambda * symm_kl
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if model.training:
logging_output.update(
symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
)
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1, 1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.label_smoothing,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)
metrics.log_scalar("symm_kl", symm_kl_sum / sample_size, sample_size, round=3)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_prediction_r3f")
class SentencePredictionR3F(FairseqCriterion):
def __init__(
self,
task,
eps,
r3f_lambda,
noise_type,
classification_head_name,
regression_target,
):
super().__init__(task)
self.eps = eps
self.r3f_lambda = r3f_lambda
self.noise_type = noise_type
self.classification_head_name = classification_head_name
self.regression_target = regression_target
if self.noise_type in {"normal"}:
self.noise_sampler = torch.distributions.normal.Normal(
loc=0.0, scale=self.eps
)
elif self.noise_type == "uniform":
self.noise_sampler = torch.distributions.uniform.Uniform(
low=-self.eps, high=self.eps
)
else:
raise Exception(f"unrecognized noise type {self.noise_type}")
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--eps', type=float, default=1e-5,
help='noise eps')
parser.add_argument('--r3f-lambda', type=float, default=1.0,
help='lambda for combining logistic loss and noisy KL loss')
parser.add_argument('--noise-type', type=str, default='uniform',
choices=['normal', 'uniform'],
help='type of noises for RXF methods')
parser.add_argument('--classification-head-name',
default='sentence_classification_head',
help='name of the classification head to use')
# fmt: on
def _get_symm_kl(self, noised_logits, input_logits):
return (
F.kl_div(
F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
F.softmax(input_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
+ F.kl_div(
F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
F.softmax(noised_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
) / noised_logits.size(0)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
token_embeddings = model.encoder.sentence_encoder.embed_tokens(
sample["net_input"]["src_tokens"]
)
input_logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
token_embeddings=token_embeddings,
)
if model.training and self.noise_sampler:
noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
token_embeddings
)
noised_embeddings = token_embeddings.detach().clone() + noise
noised_logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
token_embeddings=noised_embeddings,
)
symm_kl = self._get_symm_kl(noised_logits, input_logits)
else:
symm_kl = 0
targets = model.get_targets(sample, [input_logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
loss = F.nll_loss(
F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
targets,
reduction="sum",
)
if model.training:
symm_kl = symm_kl * sample_size
loss = loss + self.r3f_lambda * symm_kl
else:
logits = input_logits.squeeze().float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if not self.regression_target:
preds = input_logits.max(dim=1)[1]
logging_output.update(ncorrect=(preds == targets).sum().item())
if model.training and self.noise_sampler:
logging_output.update(
symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
)
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / sample_size / math.log(2),
"symm_kl": symm_kl_sum / sample_size,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
agg_output.update(accuracy=ncorrect / nsentences)
if sample_size != ntokens:
agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
return agg_output
# Scaling Neural Machine Translation (Ott et al., 2018)
This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).
## Pre-trained models
Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
## Training a new model on WMT'16 En-De
First download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
Then:
##### 1. Extract the WMT'16 En-De data
```bash
TEXT=wmt16_en_de_bpe32k
mkdir -p $TEXT
tar -xzvf wmt16_en_de.tar.gz -C $TEXT
```
##### 2. Preprocess the dataset with a joined dictionary
```bash
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $TEXT/train.tok.clean.bpe.32000 \
--validpref $TEXT/newstest2013.tok.bpe.32000 \
--testpref $TEXT/newstest2014.tok.bpe.32000 \
--destdir data-bin/wmt16_en_de_bpe32k \
--nwordssrc 32768 --nwordstgt 32768 \
--joined-dictionary \
--workers 20
```
##### 3. Train a model
```bash
fairseq-train \
data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584 \
--fp16
```
Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer.
***IMPORTANT:*** You will get better performance by training with big batches and
increasing the learning rate. If you want to train the above model with big batches
(assuming your machine has 8 GPUs):
- add `--update-freq 16` to simulate training on 8x16=128 GPUs
- increase the learning rate; 0.001 works well for big batches
##### 4. Evaluate
Now we can evaluate our trained model.
Note that the original [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
paper used a couple tricks to achieve better BLEU scores. We use these same tricks in
the Scaling NMT paper, so it's important to apply them when reproducing our results.
First, use the [average_checkpoints.py](/scripts/average_checkpoints.py) script to
average the last few checkpoints. Averaging the last 5-10 checkpoints is usually
good, but you may need to adjust this depending on how long you've trained:
```bash
python scripts/average_checkpoints.py \
--inputs /path/to/checkpoints \
--num-epoch-checkpoints 10 \
--output checkpoint.avg10.pt
```
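Conceptually, checkpoint averaging is just an element-wise mean over the saved parameter tensors. A minimal sketch of what the script does (paths are placeholders, and we assume fairseq's checkpoint layout, which stores parameters under a `"model"` key):

```python
import torch

paths = [f"/path/to/checkpoints/checkpoint{i}.pt" for i in range(91, 101)]
states = [torch.load(p, map_location="cpu")["model"] for p in paths]
avg = {k: sum(s[k].float() for s in states) / len(states) for k in states[0]}
```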
Next, generate translations using a beam width of 4 and length penalty of 0.6:
```bash
fairseq-generate \
data-bin/wmt16_en_de_bpe32k \
--path checkpoint.avg10.pt \
--beam 4 --lenpen 0.6 --remove-bpe > gen.out
```
Finally, we apply the ["compound splitting" script](/scripts/compound_split_bleu.sh) to
add spaces around dashes. For example "Café-Liebhaber" would become three tokens:
"Café - Liebhaber". This typically results in larger BLEU scores, but it is not
appropriate to compare these inflated scores to work which does not include this trick.
This trick was used in the [original AIAYN code](https://github.com/tensorflow/tensor2tensor/blob/fc9335c0203685cbbfe2b30c92db4352d8f60779/tensor2tensor/utils/get_ende_bleu.sh),
so we used it in the Scaling NMT paper as well. That said, it's strongly advised to
report [sacrebleu](https://github.com/mjpost/sacrebleu) scores instead.
To compute "compound split" tokenized BLEU (not recommended!):
```bash
bash scripts/compound_split_bleu.sh gen.out
# BLEU4 = 29.29, 60.3/35.0/22.8/15.3 (BP=1.000, ratio=1.004, syslen=64763, reflen=64496)
```
To compute detokenized BLEU with sacrebleu (preferred):
```bash
bash scripts/sacrebleu.sh wmt14/full en de gen.out
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.13a+version.1.4.3 = 28.6 59.3/34.3/22.1/14.9 (BP = 1.000 ratio = 1.016 hyp_len = 63666 ref_len = 62688)
```
## Citation
```bibtex
@inproceedings{ott2018scaling,
title = {Scaling Neural Machine Translation},
author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael},
booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)},
year = 2018,
}
```
# Simultaneous Translation
Examples of simultaneous translation in fairseq
- [English-to-Japanese text-to-text wait-k model](docs/enja-waitk.md)
- [English-to-German text-to-text monotonic multihead attention model](docs/ende-mma.md)
- [English-to-German speech-to-text simultaneous translation model](../speech_to_text/docs/simulst_mustc_example.md)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import models # noqa
# Simultaneous Machine Translation
This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS)
## Prepare Data
[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh)
Another example of training an English to Japanese model can be found [here](docs/enja.md)
## Training
- MMA-IL
```shell
fairseq-train \
data-bin/wmt15_en_de_32k \
--simul-type infinite_lookback \
--user-dir $FAIRSEQ/examples/simultaneous_translation \
--mass-preservation \
--criterion latency_augmented_label_smoothed_cross_entropy \
--latency-weight-avg 0.1 \
--max-update 50000 \
--arch transformer_monotonic_iwslt_de_en \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler 'inverse_sqrt' \
--warmup-init-lr 1e-7 --warmup-updates 4000 \
--lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
--dropout 0.3 \
--label-smoothing 0.1 \
--max-tokens 3584
```
- MMA-H
```shell
fairseq-train \
data-bin/wmt15_en_de_32k \
--simul-type hard_aligned \
--user-dir $FAIRSEQ/examples/simultaneous_translation \
--mass-preservation \
--criterion latency_augmented_label_smoothed_cross_entropy \
--latency-weight-var 0.1 \
--max-update 50000 \
--arch transformer_monotonic_iwslt_de_en \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler 'inverse_sqrt' \
--warmup-init-lr 1e-7 --warmup-updates 4000 \
--lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
--dropout 0.3 \
--label-smoothing 0.1 \
--max-tokens 3584
```
- wait-k
```shell
fairseq-train \
data-bin/wmt15_en_de_32k \
--simul-type wait-k \
--waitk-lagging 3 \
--user-dir $FAIRSEQ/examples/simultaneous_translation \
--mass-preservation \
--criterion latency_augmented_label_smoothed_cross_entropy \
--max-update 50000 \
--arch transformer_monotonic_iwslt_de_en \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler 'inverse_sqrt' \
--warmup-init-lr 1e-7 --warmup-updates 4000 \
--lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
--dropout 0.3 \
--label-smoothing 0.1 \
--max-tokens 3584
```
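The commands above assume that `$FAIRSEQ` points at a checkout of this repository and that `data-bin/wmt15_en_de_32k` contains the binarized data from the preparation step linked above; a minimal sketch with a placeholder path:
```shell
export FAIRSEQ=/path/to/fairseq   # used by --user-dir in the commands above
```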
# An Example of an English-to-Japanese Simultaneous Translation System
This is an example of training and evaluating a transformer *wait-k* English-to-Japanese simultaneous text-to-text translation model.
## Data Preparation
This section introduces the data preparation for training and evaluation.
If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference--evaluation).
For illustration, we use only the following subsets of the data available from the [WMT20 news translation task](http://www.statmt.org/wmt20/translation-task.html), which results in 7,815,391 sentence pairs:
- News Commentary v16
- Wiki Titles v3
- WikiMatrix V1
- Japanese-English Subtitle Corpus
- The Kyoto Free Translation Task Corpus
We use the WMT20 development data as the development set. Training a `transformer_vaswani_wmt_en_de_big` model on this amount of data yields 17.3 BLEU with greedy search and 19.7 with beam search (beam size 10). Note that better performance can be achieved with the full WMT training data.
We use the [sentencepiece](https://github.com/google/sentencepiece) toolkit to tokenize the data with a vocabulary size of 32000.
Additionally, we filter out sentences longer than 200 words after tokenization.
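The sentencepiece step itself is not shown in this example; below is a minimal sketch using the sentencepiece command-line tools, where the raw file names and the choice of a single shared model are assumptions:
```bash
# train one shared sentencepiece model (an assumption; per-language models
# would work the same way) and tokenize each split
spm_train --input=raw.en,raw.ja --model_prefix=enja_spm --vocab_size=32000
for lang in en ja; do
    for split in train dev test; do
        spm_encode --model=enja_spm.model --output_format=piece \
            < raw.${split}.${lang} > ${DATA_DIR}/${split}.${lang}
    done
done
```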
Assuming the tokenized text data is saved at `${DATA_DIR}`,
we binarize it with the following command.
```bash
fairseq-preprocess \
--source-lang en --target-lang ja \
--trainpref ${DATA_DIR}/train \
--validpref ${DATA_DIR}/dev \
--testpref ${DATA_DIR}/test \
--destdir ${WMT20_ENJA_DATA_BIN} \
--nwordstgt 32000 --nwordssrc 32000 \
--workers 20
```
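After this step `${WMT20_ENJA_DATA_BIN}` should contain the dictionaries and the binarized splits, following the usual `fairseq-preprocess` naming scheme:
```bash
ls ${WMT20_ENJA_DATA_BIN}
# dict.en.txt  dict.ja.txt  preprocess.log
# train.en-ja.en.bin  train.en-ja.en.idx  train.en-ja.ja.bin  train.en-ja.ja.idx
# valid.en-ja.*  test.en-ja.*  (same pattern as train)
```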
## Simultaneous Translation Model Training
To train a wait-k model (`k=10`):
```bash
fairseq-train ${WMT20_ENJA_DATA_BIN} \
--save-dir ${SAVE_DIR} \
--simul-type waitk \
--waitk-lagging 10 \
--max-epoch 70 \
--arch transformer_monotonic_vaswani_wmt_en_de_big \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-07 \
--warmup-updates 4000 \
--lr 0.0005 \
--stop-min-lr 1e-09 \
--clip-norm 10.0 \
--dropout 0.3 \
--weight-decay 0.0 \
--criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 \
--max-tokens 3584
```
This command corresponds to training on 8 GPUs. Equivalently, the model can be trained on a single GPU with `--update-freq 8`, as sketched below.
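A minimal sketch of the single-GPU variant; `--update-freq 8` accumulates gradients over 8 mini-batches so the effective batch size matches the 8-GPU run, with all other flags unchanged:
```bash
CUDA_VISIBLE_DEVICES=0 fairseq-train ${WMT20_ENJA_DATA_BIN} \
--save-dir ${SAVE_DIR} \
--update-freq 8 \
--simul-type waitk \
--waitk-lagging 10 \
--max-epoch 70 \
--arch transformer_monotonic_vaswani_wmt_en_de_big \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0005 --stop-min-lr 1e-09 --clip-norm 10.0 \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584
```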
## Inference & Evaluation
First of all, install [SimulEval](https://github.com/facebookresearch/SimulEval) for evaluation.
```bash
git clone https://github.com/facebookresearch/SimulEval.git
cd SimulEval
pip install -e .
```
The following command runs the evaluation. It assumes the source and reference files are `${SRC_FILE}` and `${TGT_FILE}`, and that the sentencepiece model for English is saved at `${SRC_SPM_PATH}`:
```bash
simuleval \
--source ${SRC_FILE} \
--target ${TGT_FILE} \
--data-bin ${WMT20_ENJA_DATA_BIN} \
--sacrebleu-tokenizer ja-mecab \
--eval-latency-unit char \
--no-space \
--src-splitter-type sentencepiecemodel \
--src-splitter-path ${SRC_SPM_PATH} \
--agent ${FAIRSEQ}/examples/simultaneous_translation/agents/simul_trans_text_agent_enja.py \
--model-path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
--output ${OUTPUT} \
--scores
```
The `--data-bin` directory should be the same as in the previous sections if you prepared the data from scratch.
If you only want to run evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_databin.tgz) and a pretrained checkpoint (wait-k=10 model) can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_wait10_ckpt.pt).
The output should look like this:
```bash
{
"Quality": {
"BLEU": 11.442253287568398
},
"Latency": {
"AL": 8.6587861866951,
"AP": 0.7863304776251316,
"DAL": 9.477850951194764
}
}
```
Latency is evaluated in characters on the target side (`--eval-latency-unit char`). Quality is evaluated with `sacrebleu` using the `MeCab` tokenizer (`--sacrebleu-tokenizer ja-mecab`). `--no-space` indicates that no space is added when merging the predicted words.
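The BLEU score above can also be reproduced offline with the `sacrebleu` command line, assuming the hypotheses and references have been extracted to plain-text files; `hyp.ja` and `ref.ja` below are placeholder names, and the `ja-mecab` tokenizer needs the optional MeCab dependencies (e.g. `pip install sacrebleu[ja]`):
```bash
# a sketch: offline BLEU with the Japanese MeCab tokenizer
sacrebleu ref.ja -tok ja-mecab < hyp.ja
```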
If the `--output ${OUTPUT}` option is used, detailed logs and scores are stored under the `${OUTPUT}` directory.
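The directory typically holds a per-sentence decision log plus the aggregated scores; the exact file names below are an assumption based on SimulEval's defaults:
```bash
ls ${OUTPUT}
# instances.log  scores   (file names are an assumption, not guaranteed)
```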
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from fairseq import checkpoint_utils, tasks
import sentencepiece as spm
import torch
try:
from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
from simuleval.agents import TextAgent
except ImportError:
print("Please install simuleval 'pip install simuleval'")
BOS_PREFIX = "\u2581"
class SimulTransTextAgentJA(TextAgent):
"""
Simultaneous Translation
Text agent for Japanese
"""
def __init__(self, args):
# Whether to run the model on GPU
self.gpu = getattr(args, "gpu", False)
# Maximum number of target tokens to generate
self.max_len = args.max_len
# Load the model and the source/target dictionaries
self.load_model_vocab(args)
# Build the subword splitters (sentencepiece)
self.build_word_splitter(args)
self.eos = DEFAULT_EOS
def initialize_states(self, states):
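# Fresh per-sentence state: a fairseq incremental decoding cache plus
# the "online" flag consumed by the monotonic decoder in policy()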
states.incremental_states = dict()
states.incremental_states["online"] = dict()
def to_device(self, tensor):
if self.gpu:
return tensor.cuda()
else:
return tensor.cpu()
def load_model_vocab(self, args):
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
task_args = state["cfg"]["task"]
task_args.data = args.data_bin
task = tasks.setup_task(task_args)
# Build the model; do not reload pretrained encoder/decoder weights at inference time
state["cfg"]["model"].load_pretrained_encoder_from = None
state["cfg"]["model"].load_pretrained_decoder_from = None
self.model = task.build_model(state["cfg"]["model"])
self.model.load_state_dict(state["model"], strict=True)
self.model.eval()
self.model.share_memory()
if self.gpu:
self.model.cuda()
# Set dictionary
self.dict = {}
self.dict["tgt"] = task.target_dictionary
self.dict["src"] = task.source_dictionary
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--model-path', type=str, required=True,
help='path to your pretrained model.')
parser.add_argument("--data-bin", type=str, required=True,
help="Path of data binary")
parser.add_argument("--max-len", type=int, default=100,
help="Max length of translation")
parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
help="Subword splitter type for target text.")
parser.add_argument("--tgt-splitter-path", type=str, default=None,
help="Subword splitter model path for target text.")
parser.add_argument("--src-splitter-type", type=str, default="SentencePiece",
help="Subword splitter type for source text.")
parser.add_argument("--src-splitter-path", type=str, default=None,
help="Subword splitter model path for source text.")
# fmt: on
return parser
def build_word_splitter(self, args):
self.spm = {}
for lang in ['src', 'tgt']:
if getattr(args, f'{lang}_splitter_type', None):
path = getattr(args, f'{lang}_splitter_path', None)
if path:
self.spm[lang] = spm.SentencePieceProcessor()
self.spm[lang].Load(path)
def segment_to_units(self, segment, states):
# Split a full word (segment) into subwords (units)
return self.spm['src'].EncodeAsPieces(segment)
def update_model_encoder(self, states):
if len(states.units.source) == 0:
return
src_indices = [
self.dict['src'].index(x)
for x in states.units.source.value
]
if states.finish_read():
# Append the EOS index once the source has been fully read
src_indices += [self.dict["tgt"].eos_index]
src_indices = self.to_device(
torch.LongTensor(src_indices).unsqueeze(0)
)
src_lengths = self.to_device(
torch.LongTensor([src_indices.size(1)])
)
states.encoder_states = self.model.encoder(src_indices, src_lengths)
torch.cuda.empty_cache()
def update_states_read(self, states):
# Happens after a read action.
self.update_model_encoder(states)
def units_to_segment(self, units, states):
# Merge subwords (units) into a full word (segment).
# For Japanese, each token can be sent to the server directly,
# minus the word-boundary prefix, as long as evaluation uses:
#   --sacrebleu-tokenizer ja-mecab
#   --eval-latency-unit char
#   --no-space
token = units.value.pop()
if (
token == self.dict["tgt"].eos_word
or len(states.segments.target) > self.max_len
):
return DEFAULT_EOS
if BOS_PREFIX == token:
return None
if token[0] == BOS_PREFIX:
return token[1:]
else:
return token
def policy(self, states):
if not getattr(states, "encoder_states", None):
# No encoder states, read a token first
return READ_ACTION
# Encode previously predicted target tokens; EOS serves as the initial BOS token
tgt_indices = self.to_device(
torch.LongTensor(
[self.model.decoder.dictionary.eos()]
+ [
self.dict['tgt'].index(x)
for x in states.units.target.value
if x is not None
]
).unsqueeze(0)
)
# Current steps
states.incremental_states["steps"] = {
"src": states.encoder_states["encoder_out"][0].size(0),
"tgt": 1 + len(states.units.target),
}
# online["only"] is True while the source has not been fully read
states.incremental_states["online"]["only"] = (
torch.BoolTensor([not states.finish_read()])
)
x, outputs = self.model.decoder.forward(
prev_output_tokens=tgt_indices,
encoder_out=states.encoder_states,
incremental_state=states.incremental_states,
)
states.decoder_out = x
torch.cuda.empty_cache()
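# action == 0 means the monotonic attention policy asks for more source
# input; otherwise a target token is written via predict()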
if outputs.action == 0:
return READ_ACTION
else:
return WRITE_ACTION
def predict(self, states):
# Predict target token from decoder states
decoder_states = states.decoder_out
lprobs = self.model.get_normalized_probs(
[decoder_states[:, -1:]], log_probs=True
)
index = lprobs.argmax(dim=-1)[0, 0].item()
if index != self.dict['tgt'].eos_index:
token = self.dict['tgt'].string([index])
else:
token = self.dict['tgt'].eos_word
return token
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
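# Automatically import every model module in this directory so their
# registration decorators run and the architectures become available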
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.models." + model_name
)