# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
@register_criterion("label_smoothed_cross_entropy_r3f")
class LabelSmoothedCrossEntropyR3FCriterion(FairseqCriterion):
def __init__(
self, task, sentence_avg, label_smoothing, eps, r3f_lambda, noise_type
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.label_smoothing = label_smoothing
self.eps = eps
self.r3f_lambda = r3f_lambda
self.noise_type = noise_type
if self.noise_type in {"normal"}:
self.noise_sampler = torch.distributions.normal.Normal(
loc=0.0, scale=self.eps
)
elif self.noise_type == "uniform":
self.noise_sampler = torch.distributions.uniform.Uniform(
low=-self.eps, high=self.eps
)
else:
raise Exception(f"unrecognized noise type {self.noise_type}")
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
parser.add_argument('--eps', type=float, default=1e-5,
help='noise eps')
parser.add_argument('--r3f-lambda', type=float, default=1.0,
help='lambda for combining logistic loss and noisy KL loss')
parser.add_argument('--noise-type', type=str, default='normal',
choices=['normal', 'uniform'],
help='type of noises')
# fmt: on
def _get_symm_kl(self, noised_logits, input_logits):
return (
F.kl_div(
F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
F.softmax(input_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
+ F.kl_div(
F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
F.softmax(noised_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
) / noised_logits.size(0)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
token_embeddings = model.encoder.embed_tokens(sample["net_input"]["src_tokens"])
input_logits, extra = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(
model, (input_logits, extra), sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
if model.training:
noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
token_embeddings
)
noised_embeddings = token_embeddings.clone() + noise
noised_logits, _ = model(
**sample["net_input"], token_embeddings=noised_embeddings
)
symm_kl = self._get_symm_kl(noised_logits, input_logits)
if model.training:
symm_kl = symm_kl * sample_size
loss = loss + self.r3f_lambda * symm_kl
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if model.training:
logging_output.update(
symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
)
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1, 1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.label_smoothing,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)
metrics.log_scalar("symm_kl", symm_kl_sum / sample_size, sample_size, round=3)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_prediction_r3f")
class SentencePredictionR3F(FairseqCriterion):
def __init__(
self,
task,
eps,
r3f_lambda,
noise_type,
classification_head_name,
regression_target,
):
super().__init__(task)
self.eps = eps
self.r3f_lambda = r3f_lambda
self.noise_type = noise_type
self.classification_head_name = classification_head_name
self.regression_target = regression_target
if self.noise_type in {"normal"}:
self.noise_sampler = torch.distributions.normal.Normal(
loc=0.0, scale=self.eps
)
elif self.noise_type == "uniform":
self.noise_sampler = torch.distributions.uniform.Uniform(
low=-self.eps, high=self.eps
)
else:
raise Exception(f"unrecognized noise type {self.noise_type}")
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--eps', type=float, default=1e-5,
help='noise eps')
parser.add_argument('--r3f-lambda', type=float, default=1.0,
help='lambda for combining logistic loss and noisy KL loss')
parser.add_argument('--noise-type', type=str, default='uniform',
choices=['normal', 'uniform'],
help='type of noises for RXF methods')
parser.add_argument('--classification-head-name',
default='sentence_classification_head',
help='name of the classification head to use')
parser.add_argument('--regression-target', action='store_true')
# fmt: on
def _get_symm_kl(self, noised_logits, input_logits):
return (
F.kl_div(
F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
F.softmax(input_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
+ F.kl_div(
F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
F.softmax(noised_logits, dim=-1, dtype=torch.float32),
None,
None,
"sum",
)
) / noised_logits.size(0)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
token_embeddings = model.encoder.sentence_encoder.embed_tokens(
sample["net_input"]["src_tokens"]
)
input_logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
token_embeddings=token_embeddings,
)
if model.training and self.noise_sampler:
noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
token_embeddings
)
noised_embeddings = token_embeddings.detach().clone() + noise
noised_logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
token_embeddings=noised_embeddings,
)
symm_kl = self._get_symm_kl(noised_logits, input_logits)
else:
symm_kl = 0
targets = model.get_targets(sample, [input_logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
loss = F.nll_loss(
F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
targets,
reduction="sum",
)
if model.training:
symm_kl = symm_kl * sample_size
loss = loss + self.r3f_lambda * symm_kl
else:
logits = input_logits.squeeze().float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if not self.regression_target:
preds = input_logits.max(dim=1)[1]
logging_output.update(ncorrect=(preds == targets).sum().item())
if model.training and self.noise_sampler:
logging_output.update(
symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
)
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / sample_size / math.log(2),
"symm_kl": symm_kl_sum / sample_size,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
agg_output.update(accuracy=ncorrect / nsentences)
if sample_size != ntokens:
agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
return agg_output
# Scaling Neural Machine Translation (Ott et al., 2018)
This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).
## Pre-trained models
Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
## Training a new model on WMT'16 En-De
First download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
Then:
##### 1. Extract the WMT'16 En-De data
```bash
TEXT=wmt16_en_de_bpe32k
mkdir -p $TEXT
tar -xzvf wmt16_en_de.tar.gz -C $TEXT
```
##### 2. Preprocess the dataset with a joined dictionary
```bash
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $TEXT/train.tok.clean.bpe.32000 \
--validpref $TEXT/newstest2013.tok.bpe.32000 \
--testpref $TEXT/newstest2014.tok.bpe.32000 \
--destdir data-bin/wmt16_en_de_bpe32k \
--nwordssrc 32768 --nwordstgt 32768 \
--joined-dictionary \
--workers 20
```
##### 3. Train a model
```bash
fairseq-train \
data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 3584 \
--fp16
```
Note that the `--fp16` flag requires that you have CUDA 9.1 or greater and a Volta GPU or newer.
***IMPORTANT:*** You will get better performance by training with big batches and
increasing the learning rate. If you want to train the above model with big batches
(assuming your machine has 8 GPUs):
- add `--update-freq 16` to simulate training on 8x16=128 GPUs
- increase the learning rate; 0.001 works well for big batches
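For example, here is a sketch of the command from step 3 with only those two changes applied (assuming 8 GPUs; everything else is unchanged):
```bash
fairseq-train \
    data-bin/wmt16_en_de_bpe32k \
    --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 3584 \
    --update-freq 16 \
    --fp16
```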
##### 4. Evaluate
Now we can evaluate our trained model.
Note that the original [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
paper used a couple of tricks to achieve better BLEU scores. We use the same tricks in
the Scaling NMT paper, so it's important to apply them when reproducing our results.
First, use the [average_checkpoints.py](/scripts/average_checkpoints.py) script to
average the last few checkpoints. Averaging the last 5-10 checkpoints is usually
good, but you may need to adjust this depending on how long you've trained:
```bash
python scripts/average_checkpoints.py \
--inputs /path/to/checkpoints \
--num-epoch-checkpoints 10 \
--output checkpoint.avg10.pt
```
Next, generate translations using a beam width of 4 and length penalty of 0.6:
```bash
fairseq-generate \
data-bin/wmt16_en_de_bpe32k \
--path checkpoint.avg10.pt \
--beam 4 --lenpen 0.6 --remove-bpe > gen.out
```
Finally, we apply the ["compound splitting" script](/scripts/compound_split_bleu.sh) to
add spaces around dashes. For example "Café-Liebhaber" would become three tokens:
"Café - Liebhaber". This typically results in larger BLEU scores, but it is not
appropriate to compare these inflated scores to work which does not include this trick.
This trick was used in the [original AIAYN code](https://github.com/tensorflow/tensor2tensor/blob/fc9335c0203685cbbfe2b30c92db4352d8f60779/tensor2tensor/utils/get_ende_bleu.sh),
so we used it in the Scaling NMT paper as well. That said, it's strongly advised to
report [sacrebleu](https://github.com/mjpost/sacrebleu) scores instead.
To compute "compound split" tokenized BLEU (not recommended!):
```bash
bash scripts/compound_split_bleu.sh gen.out
# BLEU4 = 29.29, 60.3/35.0/22.8/15.3 (BP=1.000, ratio=1.004, syslen=64763, reflen=64496)
```
To compute detokenized BLEU with sacrebleu (preferred):
```bash
bash scripts/sacrebleu.sh wmt14/full en de gen.out
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.13a+version.1.4.3 = 28.6 59.3/34.3/22.1/14.9 (BP = 1.000 ratio = 1.016 hyp_len = 63666 ref_len = 62688)
```
## Citation
```bibtex
@inproceedings{ott2018scaling,
title = {Scaling Neural Machine Translation},
author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael},
booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)},
year = 2018,
}
```
# Fine-tuning details
For each task (GLUE and PAWS), we perform a hyperparameter search for each model and report the mean and standard deviation across 5 seeds of the best model. First, get the datasets following the instructions in the [RoBERTa fine-tuning README](../roberta/README.glue.md). Alternatively, you can use [huggingface datasets](https://huggingface.co/docs/datasets/) to get the task data:
```python
from datasets import load_dataset
import pandas as pd
from pathlib import Path
key2file = {
"paws": {
"loc": "paws_data",
"columns": ["id", "sentence1", "sentence2", "label"],
"train": "train.tsv",
"validation": "dev.tsv",
"test": "test.tsv"
}
}
task_data = load_dataset("paws", "labeled_final")
task_config = key2file["paws"]
save_path = Path(task_config["loc"])
save_path.mkdir(exist_ok=True, parents=True)
for key, fl in task_config.items():
if key in ["loc", "columns"]:
continue
print(f"Reading {key}")
columns = task_config["columns"]
df = pd.DataFrame(task_data[key])
print(df.columns)
df = df[columns]
print(f"Got {len(df)} records")
save_loc = save_path / fl
print(f"Saving to : {save_loc}")
df.to_csv(save_loc, sep="\t", header=None, index=None)
```
- Preprocess using the RoBERTa GLUE preprocessing script, keeping in mind the column numbers for `sentence1`, `sentence2`, and `label` (0, 1, and 2 if you save the data according to the example above).
- Then, fine-tuning is performed similarly to RoBERTa (for example, in the case of RTE):
```bash
TOTAL_NUM_UPDATES=30875 # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=1852 # 6 percent of the number of updates
LR=2e-05 # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16 # Batch size.
SHUFFLED_ROBERTA_PATH=/path/to/shuffled_roberta/model.pt
CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \
--restore-file $SHUFFLED_ROBERTA_PATH \
--max-positions 512 \
--batch-size $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 --separator-token 2 \
--arch roberta_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--find-unused-parameters \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```
- `TOTAL_NUM_UPDATES` is computed from the `--batch_size` value and the dataset size.
- `WARMUP_UPDATES` is computed as 6% of `TOTAL_NUM_UPDATES` (see the sketch after this list).
- The best hyperparameters for `--lr` and `--batch_size` are reported in the tables below.
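As a rough sketch of that bookkeeping (the training-set size below is a placeholder, not a value from the paper):
```bash
NUM_TRAIN_EXAMPLES=50000   # placeholder: replace with the size of your training split
MAX_SENTENCES=16           # --batch_size
NUM_EPOCHS=10
TOTAL_NUM_UPDATES=$(( NUM_TRAIN_EXAMPLES * NUM_EPOCHS / MAX_SENTENCES ))
WARMUP_UPDATES=$(( TOTAL_NUM_UPDATES * 6 / 100 ))   # 6% of the total updates
echo "TOTAL_NUM_UPDATES=${TOTAL_NUM_UPDATES} WARMUP_UPDATES=${WARMUP_UPDATES}"
```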
## `--lr`
| | name | RTE | MRPC | SST-2 | CoLA | QQP | QNLI | MNLI | PAWS |
| --: | :----------- | ----: | ----: | ----: | ----: | ----: | ----: | ----: | ----: |
| 0 | original | 2e-05 | 2e-05 | 1e-05 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 2e-05 |
| 1 | n_1 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 3e-05 | 1e-05 | 2e-05 | 2e-05 |
| 2 | n_2 | 2e-05 | 2e-05 | 1e-05 | 1e-05 | 2e-05 | 1e-05 | 1e-05 | 3e-05 |
| 3 | n_3 | 3e-05 | 1e-05 | 2e-05 | 2e-05 | 3e-05 | 1e-05 | 1e-05 | 2e-05 |
| 4 | n_4 | 3e-05 | 1e-05 | 2e-05 | 2e-05 | 2e-05 | 1e-05 | 1e-05 | 2e-05 |
| 5 | r512 | 1e-05 | 3e-05 | 2e-05 | 2e-05 | 3e-05 | 2e-05 | 3e-05 | 2e-05 |
| 6 | rand_corpus | 2e-05 | 1e-05 | 3e-05 | 1e-05 | 3e-05 | 3e-05 | 3e-05 | 2e-05 |
| 7 | rand_uniform | 2e-05 | 1e-05 | 3e-05 | 2e-05 | 3e-05 | 3e-05 | 3e-05 | 1e-05 |
| 8 | rand_init | 1e-05 | 1e-05 | 3e-05 | 1e-05 | 1e-05 | 1e-05 | 2e-05 | 1e-05 |
| 9 | no_pos | 1e-05 | 3e-05 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 1e-05 | 1e-05 |
## `--batch_size`
| | name | RTE | MRPC | SST-2 | CoLA | QQP | QNLI | MNLI | PAWS |
| --: | :----------- | --: | ---: | ----: | ---: | --: | ---: | ---: | ---: |
| 0 | orig | 16 | 16 | 32 | 16 | 16 | 32 | 32 | 16 |
| 1 | n_1 | 32 | 32 | 16 | 32 | 32 | 16 | 32 | 16 |
| 2 | n_2 | 32 | 16 | 32 | 16 | 32 | 32 | 16 | 32 |
| 3 | n_3 | 32 | 32 | 16 | 32 | 32 | 16 | 32 | 32 |
| 4 | n_4 | 32 | 16 | 32 | 16 | 32 | 32 | 32 | 32 |
| 5 | r512 | 32 | 16 | 16 | 32 | 32 | 16 | 16 | 16 |
| 6 | rand_corpus | 16 | 16 | 16 | 16 | 32 | 16 | 16 | 32 |
| 7 | rand_uniform | 16 | 32 | 16 | 16 | 32 | 16 | 16 | 16 |
| 8 | rand_init | 16 | 16 | 32 | 16 | 16 | 16 | 32 | 16 |
| 9 | no_pos | 16 | 32 | 16 | 16 | 32 | 16 | 16 | 16 |
- Perform inference similarly to RoBERTa as well:
```python
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='PAWS-bin'
)
label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('paws_data/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[0], tokens[1], tokens[2]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
# Masked Language Modeling and the Distributional Hypothesis: Order Word Matters Pre-training for Little
[https://arxiv.org/abs/2104.06644](https://arxiv.org/abs/2104.06644)
## Introduction
In this work, we pre-train [RoBERTa](../roberta) base models on various word-shuffled variants of the BookWiki corpus (16GB). We observe that a word-shuffled pre-trained model achieves surprisingly good scores on GLUE, PAWS and several parametric probing tasks. Please read our paper for more details on the experiments.
## Pre-trained models
| Model | Description | Download |
| ------------------------------------- | -------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `roberta.base.orig` | RoBERTa (base) trained on natural corpus | [roberta.base.orig.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.orig.tar.gz) |
| `roberta.base.shuffle.n1` | RoBERTa (base) trained on n=1 gram sentence word shuffled data | [roberta.base.shuffle.n1.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n1.tar.gz) |
| `roberta.base.shuffle.n2` | RoBERTa (base) trained on n=2 gram sentence word shuffled data | [roberta.base.shuffle.n2.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n2.tar.gz) |
| `roberta.base.shuffle.n3` | RoBERTa (base) trained on n=3 gram sentence word shuffled data | [roberta.base.shuffle.n3.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n3.tar.gz) |
| `roberta.base.shuffle.n4` | RoBERTa (base) trained on n=4 gram sentence word shuffled data | [roberta.base.shuffle.n4.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n4.tar.gz) |
| `roberta.base.shuffle.512` | RoBERTa (base) trained on unigram 512 word block shuffled data | [roberta.base.shuffle.512.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.512.tar.gz) |
| `roberta.base.shuffle.corpus` | RoBERTa (base) trained on unigram corpus word shuffled data | [roberta.base.shuffle.corpus.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus.tar.gz) |
| `roberta.base.shuffle.corpus_uniform` | RoBERTa (base) trained on unigram corpus word shuffled data, where all words are uniformly sampled | [roberta.base.shuffle.corpus_uniform.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus_uniform.tar.gz) |
| `roberta.base.nopos` | RoBERTa (base) without positional embeddings, trained on natural corpus | [roberta.base.nopos.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.nopos.tar.gz) |
## Results
[GLUE (Wang et al, 2019)](https://gluebenchmark.com/) & [PAWS (Zhang et al, 2019)](https://github.com/google-research-datasets/paws) _(dev set, single model, single-task fine-tuning, median of 5 seeds)_
| name | CoLA | MNLI | MRPC | PAWS | QNLI | QQP | RTE | SST-2 |
| :----------------------------------- | ----: | ----: | ----: | ----: | ----: | ----: | ----: | ----: |
| `roberta.base.orig` | 61.4 | 86.11 | 89.19 | 94.46 | 92.53 | 91.26 | 74.64 | 93.92 |
| `roberta.base.shuffle.n1` | 35.15 | 82.64 | 86 | 89.97 | 89.02 | 91.01 | 69.02 | 90.47 |
| `roberta.base.shuffle.n2` | 54.37 | 83.43 | 86.24 | 93.46 | 90.44 | 91.36 | 70.83 | 91.79 |
| `roberta.base.shuffle.n3` | 48.72 | 83.85 | 86.36 | 94.05 | 91.69 | 91.24 | 70.65 | 92.02 |
| `roberta.base.shuffle.n4` | 58.64 | 83.77 | 86.98 | 94.32 | 91.69 | 91.4 | 70.83 | 92.48 |
| `roberta.base.shuffle.512` | 12.76 | 77.52 | 79.61 | 84.77 | 85.19 | 90.2 | 56.52 | 86.34 |
| `roberta.base.shuffle.corpus` | 0 | 71.9 | 70.52 | 58.52 | 71.11 | 85.52 | 53.99 | 83.35 |
| `roberta.base.shuffle.corpus_random` | 9.19 | 72.33 | 70.76 | 58.42 | 77.76 | 85.93 | 53.99 | 84.04 |
| `roberta.base.nopos` | 0 | 63.5 | 72.73 | 57.08 | 77.72 | 87.87 | 54.35 | 83.24 |
For more results on probing tasks, please refer to [our paper](https://arxiv.org/abs/2104.06644).
## Example Usage
Follow the same usage as in [RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta) to load and test your models:
```bash
# Download the roberta.base.shuffle.n1 model
wget https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n1.tar.gz
tar -xzvf roberta.base.shuffle.n1.tar.gz
```
```python
# Load the model in fairseq
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('/path/to/roberta.base.shuffle.n1', checkpoint_file='model.pt')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```
**Note**: The model trained without positional embeddings (`roberta.base.nopos`) is a modified `RoBERTa` model, where the positional embeddings are not used. Thus, the typical `from_pretrained` method on the fairseq version of RoBERTa will not be able to load the above model weights. To do so, construct a new `RobertaModel` object by setting the flag `use_positional_embeddings` to `False` (or, [in the latest code](https://github.com/pytorch/fairseq/blob/main/fairseq/models/roberta/model.py#L543), set `no_token_positional_embeddings` to `True`), and then load the individual weights.
## Fine-tuning Evaluation
We provide the fine-tuned MNLI checkpoints for each model above for quick evaluation (1 seed per model). Please refer to the [fine-tuning details](README.finetuning.md) for the hyperparameters of these models. Follow the [RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta) instructions to evaluate these models.
| Model | MNLI M Dev Accuracy | Link |
| :----------------------------------------- | :------------------ | :--------------------------------------------------------------------------------------------------------------- |
| `roberta.base.orig.mnli` | 86.14 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.orig.mnli.tar.gz) |
| `roberta.base.shuffle.n1.mnli` | 82.55 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n1.mnli.tar.gz) |
| `roberta.base.shuffle.n2.mnli` | 83.21 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n2.mnli.tar.gz) |
| `roberta.base.shuffle.n3.mnli` | 83.89 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n3.mnli.tar.gz) |
| `roberta.base.shuffle.n4.mnli` | 84.00 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n4.mnli.tar.gz) |
| `roberta.base.shuffle.512.mnli` | 77.22 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.512.mnli.tar.gz) |
| `roberta.base.shuffle.corpus.mnli` | 71.88 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus.mnli.tar.gz) |
| `roberta.base.shuffle.corpus_uniform.mnli` | 72.46 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus_uniform.mnli.tar.gz) |
## Citation
```bibtex
@misc{sinha2021masked,
title={Masked Language Modeling and the Distributional Hypothesis: Order Word Matters Pre-training for Little},
author={Koustuv Sinha and Robin Jia and Dieuwke Hupkes and Joelle Pineau and Adina Williams and Douwe Kiela},
year={2021},
eprint={2104.06644},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
# Simultaneous Translation
Examples of simultaneous translation in fairseq
- [English-to-Japanese text-to-text wait-k model](docs/enja-waitk.md)
- [English-to-German text-to-text monotonic multihead attention model](docs/ende-mma.md)
- [English-to-German speech-to-text simultaneous translation model](../speech_to_text/docs/simulst_mustc_example.md)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import models # noqa
# Simultaneous Machine Translation
This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS)
## Prepare Data
[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh)
Another example of training an English to Japanese model can be found [here](docs/enja.md)
## Training
- MMA-IL
```shell
fairseq-train \
data-bin/wmt15_en_de_32k \
--simul-type infinite_lookback \
    --user-dir $FAIRSEQ/examples/simultaneous_translation \
--mass-preservation \
--criterion latency_augmented_label_smoothed_cross_entropy \
--latency-weight-avg 0.1 \
--max-update 50000 \
    --arch transformer_monotonic_iwslt_de_en \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler 'inverse_sqrt' \
--warmup-init-lr 1e-7 --warmup-updates 4000 \
--lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
--dropout 0.3 \
--label-smoothing 0.1\
--max-tokens 3584
```
- MMA-H
```shell
fairseq-train \
data-bin/wmt15_en_de_32k \
--simul-type hard_aligned \
    --user-dir $FAIRSEQ/examples/simultaneous_translation \
--mass-preservation \
--criterion latency_augmented_label_smoothed_cross_entropy \
--latency-weight-var 0.1 \
--max-update 50000 \
    --arch transformer_monotonic_iwslt_de_en \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler 'inverse_sqrt' \
--warmup-init-lr 1e-7 --warmup-updates 4000 \
--lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
--dropout 0.3 \
--label-smoothing 0.1\
--max-tokens 3584
```
- wait-k
```shell
fairseq-train \
data-bin/wmt15_en_de_32k \
--simul-type wait-k \
--waitk-lagging 3 \
    --user-dir $FAIRSEQ/examples/simultaneous_translation \
--mass-preservation \
--criterion latency_augmented_label_smoothed_cross_entropy \
--max-update 50000 \
    --arch transformer_monotonic_iwslt_de_en \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr-scheduler 'inverse_sqrt' \
--warmup-init-lr 1e-7 --warmup-updates 4000 \
--lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
--dropout 0.3 \
--label-smoothing 0.1\
--max-tokens 3584
```
# An example of an English-to-Japanese Simultaneous Translation System
This is an example of training and evaluating a transformer *wait-k* English to Japanese simultaneous text-to-text translation model.
## Data Preparation
This section describes the data preparation for training and evaluation.
If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference--evaluation).
For illustration, we use only the following subsets of the data available from the [WMT20 news translation task](http://www.statmt.org/wmt20/translation-task.html), which together amount to 7,815,391 sentence pairs.
- News Commentary v16
- Wiki Titles v3
- WikiMatrix V1
- Japanese-English Subtitle Corpus
- The Kyoto Free Translation Task Corpus
We use the WMT20 development data as the development set. Training a `transformer_vaswani_wmt_en_de_big` model on this amount of data yields 17.3 BLEU with greedy search and 19.7 with beam search (beam size 10). Note that better performance can be achieved with the full WMT training data.
We use the [sentencepiece](https://github.com/google/sentencepiece) toolkit to tokenize the data with a vocabulary size of 32000.
Additionally, we filter out sentences longer than 200 tokens after tokenization.
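As a minimal sketch of this tokenization step with the sentencepiece command-line tools (the file layout and the choice of per-language models here are assumptions, not taken from the paper):
```bash
# Train one sentencepiece model per language with a 32000-token vocabulary,
# then encode every split into subword pieces.
mkdir -p ${DATA_DIR}
for lang in en ja; do
    spm_train --input=raw/train.${lang} --model_prefix=spm_${lang} --vocab_size=32000
    for split in train dev test; do
        spm_encode --model=spm_${lang}.model --output_format=piece \
            < raw/${split}.${lang} > ${DATA_DIR}/${split}.${lang}
    done
done
# Sentence pairs longer than 200 tokens after encoding are then dropped (not shown here).
```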
Assuming the tokenized text data is saved at `${DATA_DIR}`,
we binarize the data with the following command.
```bash
fairseq-preprocess \
--source-lang en --target-lang ja \
--trainpref ${DATA_DIR}/train \
--validpref ${DATA_DIR}/dev \
--testpref ${DATA_DIR}/test \
--destdir ${WMT20_ENJA_DATA_BIN} \
--nwordstgt 32000 --nwordssrc 32000 \
--workers 20
```
## Simultaneous Translation Model Training
To train a wait-k (k=10) model:
```bash
fairseq-train ${WMT20_ENJA_DATA_BIN} \
    --save-dir ${SAVE_DIR} \
--simul-type waitk \
--waitk-lagging 10 \
--max-epoch 70 \
--arch transformer_monotonic_vaswani_wmt_en_de_big \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-07 \
--warmup-updates 4000 \
--lr 0.0005 \
--stop-min-lr 1e-09 \
--clip-norm 10.0 \
--dropout 0.3 \
--weight-decay 0.0 \
--criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 \
--max-tokens 3584
```
This command is for training on 8 GPUs. Equivalently, the model can be trained on one GPU with `--update-freq 8`.
## Inference & Evaluation
First of all, install [SimulEval](https://github.com/facebookresearch/SimulEval) for evaluation.
```bash
git clone https://github.com/facebookresearch/SimulEval.git
cd SimulEval
pip install -e .
```
The following command runs the evaluation.
Assuming the source and reference (target) files are `${SRC_FILE}` and `${TGT_FILE}`, and the sentencepiece model for English is saved at `${SRC_SPM_PATH}`:
```bash
simuleval \
--source ${SRC_FILE} \
--target ${TGT_FILE} \
--data-bin ${WMT20_ENJA_DATA_BIN} \
--sacrebleu-tokenizer ja-mecab \
--eval-latency-unit char \
--no-space \
--src-splitter-type sentencepiecemodel \
--src-splitter-path ${SRC_SPM_PATH} \
--agent ${FAIRSEQ}/examples/simultaneous_translation/agents/simul_trans_text_agent_enja.py \
--model-path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
--output ${OUTPUT} \
--scores
```
The `--data-bin` should be the same as in the previous sections if you prepared the data from scratch.
If you only want to run the evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_databin.tgz) and a pretrained wait-k=10 checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_wait10_ckpt.pt).
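For example, both can be fetched directly (the extracted directory layout is whatever the archive contains; treat it as an assumption):
```bash
wget https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_databin.tgz
tar -xzvf wmt20_enja_medium_databin.tgz
wget https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_wait10_ckpt.pt
```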
The output should look like this:
```bash
{
"Quality": {
"BLEU": 11.442253287568398
},
"Latency": {
"AL": 8.6587861866951,
"AP": 0.7863304776251316,
"DAL": 9.477850951194764
}
}
```
The latency is evaluated in characters on the target side (`--eval-latency-unit char`). The translation quality is evaluated with `sacrebleu` using the `MeCab` tokenizer (`--sacrebleu-tokenizer ja-mecab`). `--no-space` indicates that no space is added when merging the predicted subwords.
If the `--output ${OUTPUT}` option is used, the detailed logs and scores will be stored under the `${OUTPUT}` directory.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from fairseq import checkpoint_utils, tasks
import sentencepiece as spm
import torch
try:
from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
from simuleval.agents import TextAgent
except ImportError:
print("Please install simuleval 'pip install simuleval'")
BOS_PREFIX = "\u2581"
class SimulTransTextAgentJA(TextAgent):
"""
Simultaneous Translation
Text agent for Japanese
"""
def __init__(self, args):
        # Whether to use the GPU
self.gpu = getattr(args, "gpu", False)
# Max len
self.max_len = args.max_len
# Load Model
self.load_model_vocab(args)
# build word splitter
self.build_word_splitter(args)
self.eos = DEFAULT_EOS
def initialize_states(self, states):
states.incremental_states = dict()
states.incremental_states["online"] = dict()
def to_device(self, tensor):
if self.gpu:
return tensor.cuda()
else:
return tensor.cpu()
def load_model_vocab(self, args):
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
task_args = state["cfg"]["task"]
task_args.data = args.data_bin
task = tasks.setup_task(task_args)
# build model for ensemble
state["cfg"]["model"].load_pretrained_encoder_from = None
state["cfg"]["model"].load_pretrained_decoder_from = None
self.model = task.build_model(state["cfg"]["model"])
self.model.load_state_dict(state["model"], strict=True)
self.model.eval()
self.model.share_memory()
if self.gpu:
self.model.cuda()
# Set dictionary
self.dict = {}
self.dict["tgt"] = task.target_dictionary
self.dict["src"] = task.source_dictionary
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--model-path', type=str, required=True,
help='path to your pretrained model.')
parser.add_argument("--data-bin", type=str, required=True,
help="Path of data binary")
parser.add_argument("--max-len", type=int, default=100,
help="Max length of translation")
parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
help="Subword splitter type for target text.")
parser.add_argument("--tgt-splitter-path", type=str, default=None,
help="Subword splitter model path for target text.")
parser.add_argument("--src-splitter-type", type=str, default="SentencePiece",
help="Subword splitter type for source text.")
parser.add_argument("--src-splitter-path", type=str, default=None,
help="Subword splitter model path for source text.")
# fmt: on
return parser
def build_word_splitter(self, args):
self.spm = {}
for lang in ['src', 'tgt']:
if getattr(args, f'{lang}_splitter_type', None):
path = getattr(args, f'{lang}_splitter_path', None)
if path:
self.spm[lang] = spm.SentencePieceProcessor()
self.spm[lang].Load(path)
def segment_to_units(self, segment, states):
# Split a full word (segment) into subwords (units)
return self.spm['src'].EncodeAsPieces(segment)
def update_model_encoder(self, states):
if len(states.units.source) == 0:
return
src_indices = [
self.dict['src'].index(x)
for x in states.units.source.value
]
if states.finish_read():
            # Append the eos index once the source reading is finished
src_indices += [self.dict["tgt"].eos_index]
src_indices = self.to_device(
torch.LongTensor(src_indices).unsqueeze(0)
)
src_lengths = self.to_device(
torch.LongTensor([src_indices.size(1)])
)
states.encoder_states = self.model.encoder(src_indices, src_lengths)
torch.cuda.empty_cache()
def update_states_read(self, states):
# Happens after a read action.
self.update_model_encoder(states)
def units_to_segment(self, units, states):
        # Merge subwords (units) into a full word (segment).
        # For Japanese, we can send the token to the server untokenized
        # (except for the BOS token) when using the following options:
        #   --sacrebleu-tokenizer ja-mecab
        #   --eval-latency-unit char
        #   --no-space
token = units.value.pop()
if (
token == self.dict["tgt"].eos_word
or len(states.segments.target) > self.max_len
):
return DEFAULT_EOS
if BOS_PREFIX == token:
return None
if token[0] == BOS_PREFIX:
return token[1:]
else:
return token
def policy(self, states):
if not getattr(states, "encoder_states", None):
# No encoder states, read a token first
return READ_ACTION
# encode previous predicted target tokens
tgt_indices = self.to_device(
torch.LongTensor(
[self.model.decoder.dictionary.eos()]
+ [
self.dict['tgt'].index(x)
for x in states.units.target.value
if x is not None
]
).unsqueeze(0)
)
# Current steps
states.incremental_states["steps"] = {
"src": states.encoder_states["encoder_out"][0].size(0),
"tgt": 1 + len(states.units.target),
}
# Online only means the reading is not finished
states.incremental_states["online"]["only"] = (
torch.BoolTensor([not states.finish_read()])
)
x, outputs = self.model.decoder.forward(
prev_output_tokens=tgt_indices,
encoder_out=states.encoder_states,
incremental_state=states.incremental_states,
)
states.decoder_out = x
torch.cuda.empty_cache()
if outputs.action == 0:
return READ_ACTION
else:
return WRITE_ACTION
def predict(self, states):
# Predict target token from decoder states
decoder_states = states.decoder_out
lprobs = self.model.get_normalized_probs(
[decoder_states[:, -1:]], log_probs=True
)
index = lprobs.argmax(dim=-1)[0, 0].item()
if index != self.dict['tgt'].eos_index:
token = self.dict['tgt'].string([index])
else:
token = self.dict['tgt'].eos_word
return token
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.models." + model_name
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import checkpoint_utils
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import (
ConvTransformerModel,
convtransformer_espnet,
ConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.augmented_memory_attention import (
augmented_memory,
SequenceEncoder,
AugmentedMemoryConvTransformerEncoder,
)
from torch import nn, Tensor
from typing import Dict, List
from fairseq.models.speech_to_text.modules.emformer import NoSegAugmentedMemoryTransformerEncoderLayer
@register_model("convtransformer_simul_trans")
class SimulConvTransformerModel(ConvTransformerModel):
"""
Implementation of the paper:
SimulMT to SimulST: Adapting Simultaneous Text Translation to
End-to-End Simultaneous Speech Translation
https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
"""
@staticmethod
def add_args(parser):
super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser)
parser.add_argument(
"--train-monotonic-only",
action="store_true",
default=False,
help="Only train monotonic attention",
)
@classmethod
def build_decoder(cls, args, task, embed_tokens):
tgt_dict = task.tgt_dict
from examples.simultaneous_translation.models.transformer_monotonic_attention import (
TransformerMonotonicDecoder,
)
decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from
)
return decoder
@register_model_architecture(
"convtransformer_simul_trans", "convtransformer_simul_trans_espnet"
)
def convtransformer_simul_trans_espnet(args):
convtransformer_espnet(args)
@register_model("convtransformer_augmented_memory")
@augmented_memory
class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel):
@classmethod
def build_encoder(cls, args):
encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args))
if getattr(args, "load_pretrained_encoder_from", None) is not None:
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
return encoder
@register_model_architecture(
"convtransformer_augmented_memory", "convtransformer_augmented_memory"
)
def augmented_memory_convtransformer_espnet(args):
convtransformer_espnet(args)
# ============================================================================ #
# Convtransformer
# with monotonic attention decoder
# with emformer encoder
# ============================================================================ #
class ConvTransformerEmformerEncoder(ConvTransformerEncoder):
def __init__(self, args):
super().__init__(args)
stride = self.conv_layer_stride(args)
trf_left_context = args.segment_left_context // stride
trf_right_context = args.segment_right_context // stride
context_config = [trf_left_context, trf_right_context]
self.transformer_layers = nn.ModuleList(
[
NoSegAugmentedMemoryTransformerEncoderLayer(
input_dim=args.encoder_embed_dim,
num_heads=args.encoder_attention_heads,
ffn_dim=args.encoder_ffn_embed_dim,
num_layers=args.encoder_layers,
dropout_in_attn=args.dropout,
dropout_on_attn=args.dropout,
dropout_on_fc1=args.dropout,
dropout_on_fc2=args.dropout,
activation_fn=args.activation_fn,
context_config=context_config,
segment_size=args.segment_length,
max_memory_size=args.max_memory_size,
scaled_init=True, # TODO: use constant for now.
tanh_on_mem=args.amtrf_tanh_on_mem,
)
]
)
self.conv_transformer_encoder = ConvTransformerEncoder(args)
def forward(self, src_tokens, src_lengths):
encoder_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(src_tokens, src_lengths.to(src_tokens.device))
output = encoder_out["encoder_out"][0]
encoder_padding_masks = encoder_out["encoder_padding_mask"]
return {
"encoder_out": [output],
            # This is because, in the original implementation, the output
            # did not treat the last segment as right context.
"encoder_padding_mask": [encoder_padding_masks[0][:, : output.size(0)]] if len(encoder_padding_masks) > 0
else [],
"encoder_embedding": [],
"encoder_states": [],
"src_tokens": [],
"src_lengths": [],
}
@staticmethod
def conv_layer_stride(args):
# TODO: make it configurable from the args
return 4
@register_model("convtransformer_emformer")
class ConvtransformerEmformer(SimulConvTransformerModel):
@staticmethod
def add_args(parser):
super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser)
parser.add_argument(
"--segment-length",
type=int,
metavar="N",
help="length of each segment (not including left context / right context)",
)
parser.add_argument(
"--segment-left-context",
type=int,
help="length of left context in a segment",
)
parser.add_argument(
"--segment-right-context",
type=int,
help="length of right context in a segment",
)
parser.add_argument(
"--max-memory-size",
type=int,
default=-1,
help="Right context for the segment.",
)
parser.add_argument(
"--amtrf-tanh-on-mem",
default=False,
action="store_true",
help="whether to use tanh on memory vector",
)
@classmethod
def build_encoder(cls, args):
encoder = ConvTransformerEmformerEncoder(args)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
return encoder
@register_model_architecture(
"convtransformer_emformer",
"convtransformer_emformer",
)
def convtransformer_emformer_base(args):
convtransformer_espnet(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, NamedTuple, Optional
import torch
import torch.nn as nn
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
TransformerMonotonicDecoderLayer,
TransformerMonotonicEncoderLayer,
)
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
TransformerModel,
TransformerEncoder,
TransformerDecoder,
base_architecture,
transformer_iwslt_de_en,
    transformer_vaswani_wmt_en_de_big,
    transformer_vaswani_wmt_en_fr_big,
tiny_architecture
)
from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
READ_ACTION = 0
WRITE_ACTION = 1
TransformerMonotonicDecoderOut = NamedTuple(
"TransformerMonotonicDecoderOut",
[
("action", int),
("p_choose", Optional[Tensor]),
("attn_list", Optional[List[Optional[Dict[str, Tensor]]]]),
("encoder_out", Optional[Dict[str, List[Tensor]]]),
("encoder_padding_mask", Optional[Tensor]),
],
)
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@register_model("transformer_monotonic")
class TransformerModelSimulTrans(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
class TransformerMonotonicEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[
TransformerMonotonicEncoderLayer(args)
for i in range(args.encoder_layers)
]
)
class TransformerMonotonicDecoder(TransformerDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[
TransformerMonotonicDecoderLayer(args)
for _ in range(args.decoder_layers)
]
)
self.policy_criterion = getattr(args, "policy_criterion", "any")
self.num_updates = None
def set_num_updates(self, num_updates):
self.num_updates = num_updates
def pre_attention(
self,
prev_output_tokens,
encoder_out_dict: Dict[str, List[Tensor]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_out = encoder_out_dict["encoder_out"][0]
if "encoder_padding_mask" in encoder_out_dict:
encoder_padding_mask = (
encoder_out_dict["encoder_padding_mask"][0]
if encoder_out_dict["encoder_padding_mask"]
and len(encoder_out_dict["encoder_padding_mask"]) > 0
else None
)
else:
encoder_padding_mask = None
return x, encoder_out, encoder_padding_mask
def post_attention(self, x):
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x
def clean_cache(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
end_id: Optional[int] = None,
):
"""
Clean cache in the monotonic layers.
        The cache is created when a forward pass of the decoder has run but no
        prediction was made, so the decoder self-attention key/value is already
        written in the incremental state.
        end_id is the index of the last layer to clean.
"""
if end_id is None:
end_id = len(self.layers)
for index, layer in enumerate(self.layers):
if index < end_id:
layer.prune_incremental_state(incremental_state)
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False, # unused
alignment_layer: Optional[int] = None, # unused
        alignment_heads: Optional[int] = None,  # unused
):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
# incremental_state = None
assert encoder_out is not None
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
prev_output_tokens, encoder_out, incremental_state
)
attn = None
inner_states = [x]
attn_list: List[Optional[Dict[str, Tensor]]] = []
p_choose = torch.tensor([1.0])
for i, layer in enumerate(self.layers):
x, attn, _ = layer(
x=x,
encoder_out=encoder_outs,
encoder_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None
else None,
)
inner_states.append(x)
attn_list.append(attn)
if incremental_state is not None:
if_online = incremental_state["online"]["only"]
assert if_online is not None
if if_online.to(torch.bool):
# Online indicates that the encoder states are still changing
assert attn is not None
if self.policy_criterion == "any":
                        # If any head decides to read, then read
head_read = layer.encoder_attn._get_monotonic_buffer(incremental_state)["head_read"]
assert head_read is not None
if head_read.any():
# We need to prune the last self_attn saved_state
                            # if the model decides not to read;
# otherwise there will be duplicated saved_state
self.clean_cache(incremental_state, i + 1)
return x, TransformerMonotonicDecoderOut(
action=0,
p_choose=p_choose,
attn_list=None,
encoder_out=None,
encoder_padding_mask=None,
)
x = self.post_attention(x)
return x, TransformerMonotonicDecoderOut(
action=1,
p_choose=p_choose,
attn_list=attn_list,
encoder_out=encoder_out,
encoder_padding_mask=encoder_padding_mask,
)
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_architecture(args):
base_architecture(args)
args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
transformer_iwslt_de_en(args)
base_monotonic_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    transformer_vaswani_wmt_en_fr_big(args)
@register_model_architecture(
"transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
transformer_iwslt_de_en(args)
@register_model_architecture("transformer_monotonic", "transformer_monotonic_tiny")
def monotonic_tiny_architecture(args):
tiny_architecture(args)
base_monotonic_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import importlib
from fairseq import registry
(
build_monotonic_attention,
register_monotonic_attention,
MONOTONIC_ATTENTION_REGISTRY,
_,
) = registry.setup_registry("--simul-type")
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.modules." + model_name
)
from functools import partial
import torch
from torch import Tensor
import math
import torch.nn.functional as F
from . import register_monotonic_attention
from .monotonic_multihead_attention import (
MonotonicAttention,
MonotonicInfiniteLookbackAttention,
WaitKAttention
)
from typing import Dict, Optional
def fixed_pooling_monotonic_attention(monotonic_attention):
def create_model(monotonic_attention, klass):
class FixedStrideMonotonicAttention(monotonic_attention):
def __init__(self, args):
self.waitk_lagging = 0
self.num_heads = 0
self.noise_mean = 0.0
self.noise_var = 0.0
super().__init__(args)
self.pre_decision_type = args.fixed_pre_decision_type
self.pre_decision_ratio = args.fixed_pre_decision_ratio
self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold
assert self.pre_decision_ratio > 1
if args.fixed_pre_decision_type == "average":
self.pooling_layer = torch.nn.AvgPool1d(
kernel_size=self.pre_decision_ratio,
stride=self.pre_decision_ratio,
ceil_mode=True,
)
elif args.fixed_pre_decision_type == "last":
def last(key):
if key.size(2) < self.pre_decision_ratio:
return key
else:
k = key[
:,
:,
self.pre_decision_ratio - 1:: self.pre_decision_ratio,
].contiguous()
if key.size(-1) % self.pre_decision_ratio != 0:
k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous()
return k
self.pooling_layer = last
else:
raise NotImplementedError
@staticmethod
def add_args(parser):
super(
FixedStrideMonotonicAttention, FixedStrideMonotonicAttention
).add_args(parser)
parser.add_argument(
"--fixed-pre-decision-ratio",
type=int,
required=True,
help=(
"Ratio for the fixed pre-decision,"
"indicating how many encoder steps will start"
"simultaneous decision making process."
),
)
parser.add_argument(
"--fixed-pre-decision-type",
default="average",
choices=["average", "last"],
help="Pooling type",
)
parser.add_argument(
"--fixed-pre-decision-pad-threshold",
type=float,
default=0.3,
help="If a part of the sequence has pad"
",the threshold the pooled part is a pad.",
)
def insert_zeros(self, x):
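"""
Upsample the pooled p_choose back to the original source length by
interleaving zeros: each pooled value is placed at the last position of its
`pre_decision_ratio`-sized block, implemented with a strided transposed
1-D convolution.
"""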
bsz_num_heads, tgt_len, src_len = x.size()
stride = self.pre_decision_ratio
weight = F.pad(torch.ones(1, 1, 1).to(x), (stride - 1, 0))
x_upsample = F.conv_transpose1d(
x.view(-1, src_len).unsqueeze(1),
weight,
stride=stride,
padding=0,
)
return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1)
def p_choose(
self,
query: Optional[Tensor],
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
assert key is not None
assert query is not None
src_len = key.size(0)
tgt_len = query.size(0)
batch_size = query.size(1)
key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2)
if key_padding_mask is not None:
key_padding_mask_pool = (
self.pooling_layer(key_padding_mask.unsqueeze(0).float())
.squeeze(0)
.gt(self.pre_decision_pad_threshold)
)
# Make sure at least one element is not pad
key_padding_mask_pool[:, 0] = 0
else:
key_padding_mask_pool = None
if incremental_state is not None:
# Use floor instead of ceil at inference time,
# but make sure key_pool has length at least 1
if (
max(1, math.floor(key.size(0) / self.pre_decision_ratio))
) < key_pool.size(0):
key_pool = key_pool[:-1]
if key_padding_mask_pool is not None:
key_padding_mask_pool = key_padding_mask_pool[:-1]
p_choose_pooled = self.p_choose_from_qk(
query,
key_pool,
key_padding_mask_pool,
incremental_state=incremental_state,
)
# Upsample, interpolate zeros
p_choose = self.insert_zeros(p_choose_pooled)
if p_choose.size(-1) < src_len:
# Append zeros if the upsampled p_choose is shorter than src_len
p_choose = torch.cat(
[
p_choose,
torch.zeros(
p_choose.size(0),
tgt_len,
src_len - p_choose.size(-1)
).to(p_choose)
],
dim=2
)
else:
# The upsampled p_choose can be longer than src_len because ceil_mode was used when pooling
p_choose = p_choose[:, :, :src_len]
p_choose[:, :, -1] = p_choose_pooled[:, :, -1]
assert list(p_choose.size()) == [
batch_size * self.num_heads,
tgt_len,
src_len,
]
return p_choose
FixedStrideMonotonicAttention.__name__ = klass.__name__
return FixedStrideMonotonicAttention
return partial(create_model, monotonic_attention)
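# Note: the decorator returned here replaces the decorated placeholder class
# entirely; create_model subclasses the wrapped attention and reuses only the
# placeholder's __name__.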
@register_monotonic_attention("waitk_fixed_pre_decision")
@fixed_pooling_monotonic_attention(WaitKAttention)
class WaitKAttentionFixedStride:
pass
@register_monotonic_attention("hard_aligned_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicAttention)
class MonotonicAttentionFixedStride:
pass
@register_monotonic_attention("infinite_lookback_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicInfiniteLookbackAttention)
class MonotonicInfiniteLookbackAttentionFixedStride:
pass
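# Minimal configuration sketch (illustrative only; the full fairseq-train command
# is omitted): combining wait-k with fixed pre-decision pooling would use the
# flags registered above, e.g.
#   --simul-type waitk_fixed_pre_decision --waitk-lagging 3 \
#   --fixed-pre-decision-ratio 7 --fixed-pre-decision-type average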
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import Tensor
import torch.nn as nn
from examples.simultaneous_translation.utils.p_choose_strategy import (
learnable_p_choose,
waitk_p_choose
)
from examples.simultaneous_translation.utils.monotonic_attention import (
expected_alignment_from_p_choose,
expected_soft_attention,
mass_preservation,
)
from fairseq.modules import MultiheadAttention
from . import register_monotonic_attention
from typing import Dict, Optional
@register_monotonic_attention("hard_aligned")
class MonotonicAttention(MultiheadAttention):
"""
Abstract class for monotonic attention mechanisms
"""
k_in_proj: Dict[str, nn.Linear]
q_in_proj: Dict[str, nn.Linear]
def __init__(self, args):
super().__init__(
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.soft_attention = False
self.eps = getattr(args, "attention_eps", True)
self.mass_preservation = getattr(args, "mass_preservation", True)
self.noise_type = args.noise_type
self.noise_mean = args.noise_mean
self.noise_var = args.noise_var
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True
else 0
)
self.k_in_proj = {"monotonic": self.k_proj}
self.q_in_proj = {"monotonic": self.q_proj}
self.chunk_size = None
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false",
dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true",
dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
parser.add_argument('--noise-var', type=float, default=1.0,
help='Variance of discreteness noise')
parser.add_argument('--noise-mean', type=float, default=0.0,
help='Mean of discreteness noise')
parser.add_argument('--noise-type', type=str, default="flat",
help='Type of discreteness noise')
parser.add_argument('--energy-bias', action="store_true",
default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
def energy_from_qk(
self,
query: Tensor,
key: Tensor,
energy_type: str,
key_padding_mask: Optional[Tensor] = None,
bias: int = 0
):
"""
Compute energy from query and key
q_func_value is a tuple looks like
(q_proj_func, q_tensor)
q_tensor size: bsz, tgt_len, emb_dim
k_tensor size: bsz, src_len, emb_dim
key_padding_mask size: bsz, src_len
attn_mask: bsz, src_len
"""
length, bsz, _ = query.size()
q = self.q_in_proj[energy_type].forward(query)
q = (
q.contiguous()
.view(length, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
q = q * self.scaling
length, bsz, _ = key.size()
k = self.k_in_proj[energy_type].forward(key)
k = (
k.contiguous()
.view(length, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
energy = torch.bmm(q, k.transpose(1, 2)) + bias
if key_padding_mask is not None:
energy = energy.masked_fill(
key_padding_mask.unsqueeze(1).to(torch.bool),
- float("inf")
)
return energy
def p_choose_from_qk(self, query, key, key_padding_mask, incremental_state=None):
monotonic_energy = self.energy_from_qk(
query,
key,
"monotonic",
key_padding_mask=key_padding_mask,
bias=self.energy_bias,
)
p_choose = learnable_p_choose(
monotonic_energy,
self.noise_mean,
self.noise_var,
self.training
)
return p_choose
def p_choose(self, query, key, key_padding_mask, incremental_state=None):
return self.p_choose_from_qk(query, key, key_padding_mask, incremental_state)
def monotonic_attention_process_infer(
self,
query: Optional[Tensor],
key: Optional[Tensor],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
):
"""
Monotonic attention at inference time
Notice that this function is designed for simuleval not sequence_generator
"""
assert query is not None
assert key is not None
if query.size(1) != 1:
raise RuntimeError(
"Simultaneous translation models don't support batch decoding."
)
# 1. compute stepwise probability
p_choose = self.p_choose(
query, key, None, incremental_state
).squeeze(1)
# 2. Compute the alpha
src_len = key.size(0)
# Maximum number of source steps allowed in this iteration
max_steps = src_len - 1 if self.mass_preservation else src_len
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# Step for each head
monotonic_step = monotonic_cache.get(
'head_step',
p_choose.new_zeros(1, self.num_heads).long()
)
assert monotonic_step is not None
finish_read = monotonic_step.eq(max_steps)
p_choose_i = torch.tensor(1)
while finish_read.sum().item() < self.num_heads:
# p_choose: self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: 1, self.num_heads
p_choose_i = (
p_choose.gather(
1,
monotonic_step
.clamp(0, src_len - 1),
)
)
read_one_step = (
(p_choose_i < 0.5)
.type_as(monotonic_step)
.masked_fill(finish_read, 0)
)
# read_one_step: 1 x num_heads
# 1 means the head keeps reading (advance one source step)
# 0 means the head stops reading and is ready to write
monotonic_step += read_one_step
finish_read = monotonic_step.eq(max_steps) | (read_one_step == 0)
# p_choose at last steps
p_choose_i = (
p_choose.gather(
1,
monotonic_step
.clamp(0, src_len - 1),
)
)
monotonic_cache["head_step"] = monotonic_step
# Whether a head is looking for new input
monotonic_cache["head_read"] = (
monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
# 3. Update alpha from the final monotonic steps
alpha = (
p_choose
.new_zeros([self.num_heads, src_len])
.scatter(
1,
(monotonic_step)
.view(self.num_heads, 1).clamp(0, src_len - 1),
1
)
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(monotonic_step == max_steps)
.view(self.num_heads, 1),
0
)
# 4. Compute Beta
if self.soft_attention:
monotonic_step = monotonic_step.t()
beta_mask = torch.arange(src_len).expand_as(alpha).gt(monotonic_step).unsqueeze(1)
# If it's soft attention just do softmax on current context
soft_energy = self.energy_from_qk(
query,
key,
"soft"
)
beta = torch.nn.functional.softmax(
soft_energy.masked_fill(beta_mask, -float("inf")), dim=-1
)
# It could happen that a head doesn't move at all
beta = beta.masked_fill(monotonic_step.eq(0).unsqueeze(1), 0)
else:
# If it's hard attention just select the last state
beta = alpha
return p_choose, alpha, beta
def monotonic_attention_process_train(
self,
query: Optional[Tensor],
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
):
"""
Calculating monotonic attention process for training
Including:
stepwise probability: p_choose
expected hard alignment: alpha
expected soft attention: beta
"""
assert query is not None
assert key is not None
# 1. compute stepwise probability
p_choose = self.p_choose_from_qk(query, key, key_padding_mask)
# 2. compute expected_alignment
alpha = expected_alignment_from_p_choose(
p_choose,
key_padding_mask,
eps=self.eps,
)
if self.mass_preservation:
alpha = mass_preservation(
alpha, key_padding_mask
)
# 3. compute expected soft attention (soft aligned model only)
if self.soft_attention:
soft_energy = self.energy_from_qk(
query,
key,
"soft",
key_padding_mask=None,
)
beta = expected_soft_attention(
alpha,
soft_energy,
padding_mask=key_padding_mask,
chunk_size=self.chunk_size,
eps=self.eps,
)
else:
beta = alpha
soft_energy = alpha
return p_choose, alpha, beta, soft_energy
def forward(
self,
query: Optional[Tensor],
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
need_head_weights: bool = False,
):
"""
query: tgt_len, bsz, embed_dim
key: src_len, bsz, embed_dim
value: src_len, bsz, embed_dim
"""
assert attn_mask is None
assert query is not None
assert key is not None
assert value is not None
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
if key_padding_mask is not None:
assert not key_padding_mask[:, 0].any(), (
"Only right padding is supported."
)
key_padding_mask = (
key_padding_mask
.unsqueeze(1)
.expand([bsz, self.num_heads, src_len])
.contiguous()
.view(-1, src_len)
)
if incremental_state is not None:
# Inference
(
p_choose, alpha, beta
) = self.monotonic_attention_process_infer(
query, key, incremental_state
)
soft_energy = beta
else:
# Train
(
p_choose, alpha, beta, soft_energy
) = self.monotonic_attention_process_train(
query, key, key_padding_mask
)
v = self.v_proj(value)
length, bsz, _ = v.size()
v = (
v.contiguous()
.view(length, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
attn = torch.bmm(beta.type_as(v), v)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {
"p_choose": p_choose,
"alpha": alpha,
"beta": beta,
}
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
maybe_incremental_state = self.get_incremental_state(
incremental_state,
'monotonic',
)
if maybe_incremental_state is None:
typed_empty_dict: Dict[str, Optional[Tensor]] = {}
return typed_empty_dict
else:
return maybe_incremental_state
def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
self.set_incremental_state(
incremental_state,
'monotonic',
buffer,
)
@register_monotonic_attention("infinite_lookback")
class MonotonicInfiniteLookbackAttention(
MonotonicAttention
):
def __init__(self, args):
super().__init__(args)
self.soft_attention = True
self.init_soft_attention()
def init_soft_attention(self):
self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.k_in_proj["soft"] = self.k_proj_soft
self.q_in_proj["soft"] = self.q_proj_soft
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
@register_monotonic_attention("waitk")
class WaitKAttention(
MonotonicInfiniteLookbackAttention
):
"""
STACL: Simultaneous Translation with Implicit Anticipation and
Controllable Latency using Prefix-to-Prefix Framework
https://www.aclweb.org/anthology/P19-1289/
"""
def __init__(self, args):
super().__init__(args)
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
assert self.waitk_lagging > 0, (
f"Lagging has to been larger than 0, get {self.waitk_lagging}."
)
@staticmethod
def add_args(parser):
super(
MonotonicInfiniteLookbackAttention,
MonotonicInfiniteLookbackAttention
).add_args(parser)
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait K lagging"
)
def p_choose_from_qk(
self,
query: Optional[Tensor],
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
assert query is not None
assert key is not None
p_choose = waitk_p_choose(
tgt_len=query.size(0),
src_len=key.size(0),
bsz=query.size(1) * self.num_heads,
waitk_lagging=self.waitk_lagging,
key_padding_mask=key_padding_mask,
incremental_state=incremental_state,
)
return p_choose.to(query)
@register_monotonic_attention("chunkwise")
class ChunkwiseAttention(
MonotonicInfiniteLookbackAttention
):
def __init__(self, args):
super().__init__(args)
self.chunk_size = args.mocha_chunk_size
assert self.chunk_size > 1
@staticmethod
def add_args(parser):
super(
MonotonicInfiniteLookbackAttention,
MonotonicInfiniteLookbackAttention
).add_args(parser)
parser.add_argument(
"--mocha-chunk-size", type=int,
required=True, help="Mocha chunk size"
)
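# Minimal construction sketch (illustrative, not part of the original file): the
# argparse.Namespace fields mirror DEFAULT_CONFIG in the tests further below, and
# the dimensions are assumptions chosen only to make the example self-contained.
#
#   import argparse, torch
#   args = argparse.Namespace(
#       decoder_embed_dim=12, decoder_attention_heads=2, encoder_embed_dim=12,
#       attention_dropout=0.0, attention_eps=1e-6, mass_preservation=True,
#       noise_type="flat", noise_mean=0.0, noise_var=1.0,
#       energy_bias=True, energy_bias_init=-2.0, waitk_lagging=3,
#   )
#   attn = WaitKAttention(args)
#   q = torch.rand(4, 1, 12)      # tgt_len, bsz, embed_dim
#   k = v = torch.rand(6, 1, 12)  # src_len, bsz, embed_dim
#   out, extra = attn(q, k, v)    # extra holds "p_choose", "alpha", "beta"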
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from . import build_monotonic_attention
from typing import Dict, Optional, List
from torch import Tensor
import torch
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
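"""
Encoder layer with a causal (upper-triangular) self-attention mask, so each
encoder state attends only to earlier source positions and does not change as
new source tokens stream in.
"""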
def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask)
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
def __init__(self, args):
super().__init__(args)
assert args.simul_type is not None, "A --simul-type is needed."
self.encoder_attn = build_monotonic_attention(args)
def prune_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
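# Drop the most recent self-attention key/value from the incremental state.
# In simultaneous decoding this is presumably used to roll back the last
# speculative decoder step after the policy decides to read more source input.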
input_buffer = self.self_attn._get_input_buffer(incremental_state)
for key in ["prev_key", "prev_value"]:
input_buffer_key = input_buffer[key]
assert input_buffer_key is not None
if input_buffer_key.size(2) > 1:
input_buffer[key] = input_buffer_key[:, :, :-1, :]
else:
typed_empty_dict: Dict[str, Optional[Tensor]] = {}
input_buffer = typed_empty_dict
break
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, input_buffer)
def forward(
self,
x,
encoder_out: Optional[Tensor] = None,
encoder_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[Tensor]] = None,
prev_attn_state: Optional[List[Tensor]] = None,
self_attn_mask: Optional[Tensor] = None,
self_attn_padding_mask: Optional[Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
assert self.encoder_attn is not None
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
import argparse
import unittest
from typing import Any, Dict
import torch
from examples.simultaneous_translation.models import (
transformer_monotonic_attention
)
from tests.test_roberta import FakeTask
DEFAULT_CONFIG = {
"attention_eps": 1e-6,
"mass_preservation": True,
"noise_type": "flat",
"noise_mean": 0.0,
"noise_var": 1.0,
"energy_bias_init": -2,
"energy_bias": True
}
PAD_INDEX = 1
def generate_config(overrides_kv):
new_dict = dict(DEFAULT_CONFIG)
for key, value in overrides_kv.items():
new_dict[key] = value
return new_dict
def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
tokens_1 = torch.LongTensor(
[
[2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
[
2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2,
PAD_INDEX, PAD_INDEX
],
]
)
tokens_2 = torch.LongTensor(
[
[2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX],
[2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX]
]
)
if longer_src:
src_tokens = tokens_1[:, 1:]
prev_output_tokens = tokens_2
else:
src_tokens = tokens_2[:, 1:8]
prev_output_tokens = tokens_1
src_lengths = src_tokens.ne(PAD_INDEX).sum(dim=1).long()
sample = {
"net_input": {
"src_tokens": src_tokens,
"prev_output_tokens": prev_output_tokens,
"src_lengths": src_lengths,
},
"target": prev_output_tokens[:, 1:],
}
return sample
def build_transformer_monotonic_attention(**extra_args: Any):
overrides = {
# Use small, characteristic model dimensions
"encoder_embed_dim": 12,
"encoder_ffn_embed_dim": 14,
"decoder_embed_dim": 12,
"decoder_ffn_embed_dim": 14,
# Disable dropout so we have comparable tests.
"dropout": 0,
"attention_dropout": 0,
"activation_dropout": 0,
"encoder_layerdrop": 0,
}
overrides.update(extra_args)
# Overrides the defaults from the parser
args = argparse.Namespace(**overrides)
transformer_monotonic_attention.monotonic_tiny_architecture(args)
torch.manual_seed(0)
task = FakeTask(args)
return (
transformer_monotonic_attention
.TransformerModelSimulTrans
.build_model(args, task)
)
def expected_alignment_formula(
p_choose,
mass_preservation=True,
padding_mask=None
):
# Online and Linear-Time Attention by Enforcing Monotonic Alignments
# https://arxiv.org/pdf/1704.00784.pdf
# Eq 18, 19
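# The recurrence being checked (same notation as the code below):
#   alpha[i, j] = p_choose[i, j] * sum_{k <= j} alpha[i-1, k] * prod_{l=k}^{j-1} (1 - p_choose[i, l])
# with alpha[0, j] = p_choose[0, j] * prod_{l < j} (1 - p_choose[0, l])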
bsz, tgt_len, src_len = p_choose.size()
alpha = torch.zeros_like(p_choose)
if padding_mask is not None:
bsz_pad = padding_mask.size(0)
num_heads = int(bsz / bsz_pad)
padding_mask = (
padding_mask
.unsqueeze(1)
.expand([bsz_pad, num_heads, src_len])
.contiguous()
.view(-1, src_len)
)
p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)
for bsz_i in range(bsz):
for i in range(tgt_len):
for j in range(src_len):
if i == 0:
if j == 0:
# First target step, first source position
alpha[bsz_i, i, j] = p_choose[bsz_i, i, j]
else:
# First target step, later source positions
alpha[bsz_i, i, j] = (
p_choose[bsz_i, i, j]
* torch.prod(
1 - p_choose[bsz_i, i, :j]
)
)
else:
alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
for k in range(j):
alpha[bsz_i, i, j] += (
alpha[bsz_i, i - 1, k]
* torch.prod(
1 - p_choose[bsz_i, i, k:j]
)
)
alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]
alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)
if mass_preservation:
alpha = mass_preservation_formula(alpha, False, padding_mask)
return alpha
def mass_preservation_formula(alpha, left_padding=False, padding_mask=None):
if padding_mask is None or alpha.size(-1) == 1:
if alpha.size(-1) > 1:
alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
return alpha
src_lens = (padding_mask.logical_not()).sum(dim=1).long()
bsz, tgt_len, src_len = alpha.size()
assert (
not left_padding
or (left_padding and (not padding_mask[:, 0].any()))
)
alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)
for bsz_i in range(bsz):
if left_padding:
alpha[bsz_i, :, -1] = (
1 - alpha[bsz_i, :, :-1].sum(dim=-1)
)
else:
alpha[bsz_i, :, src_lens[bsz_i] - 1] = (
1 - alpha[bsz_i, :, :src_lens[bsz_i] - 1].sum(dim=-1)
)
return alpha
def expected_soft_attention_formula(
alpha,
soft_energy,
padding_mask=None,
chunksize=int(1e10),
):
# Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
# https://arxiv.org/pdf/1906.05218.pdf
# Eq 14
# Monotonic Chunkwise Attention
# https://arxiv.org/abs/1712.05382
# Eq 17
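# The quantity being checked, for chunk size w (w -> infinity recovers the
# infinite-lookback case):
#   beta[i, j] = sum_{k=j}^{j+w-1} alpha[i, k] * exp(u[i, j]) / sum_{l=k-w+1}^{k} exp(u[i, l])
# where u is the soft attention energy and indices are clipped to the valid range.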
bsz, tgt_len, src_len = alpha.size()
beta = torch.zeros_like(alpha)
if padding_mask is not None:
bsz_pad = padding_mask.size(0)
num_heads = int(bsz / bsz_pad)
# Expanding for potential head dimension
padding_mask = (
padding_mask
.unsqueeze(1)
.expand([bsz_pad, num_heads, src_len])
.contiguous()
.view(-1, src_len)
)
soft_energy = soft_energy.masked_fill(padding_mask.unsqueeze(1), float('-inf'))
for bsz_i in range(bsz):
for i in range(tgt_len):
for j in range(src_len):
for k in range(j, min([src_len, j + chunksize])):
if not padding_mask[bsz_i, j]:
beta[bsz_i, i, j] += (
alpha[bsz_i, i, k] * torch.exp(soft_energy[bsz_i, i, j])
/ torch.sum(torch.exp(soft_energy[bsz_i, i, max([0, k - chunksize + 1]):k + 1]))
)
return beta
class MonotonicAttentionTestAbstractClass(object):
def test_forward(self):
sample = make_sample_with_padding()
out, _ = self.model.forward(**sample["net_input"])
loss = out.sum()
loss.backward()
def test_p_choose(self):
sample = make_sample_with_padding()
_, extra_out = self.model.forward(**sample["net_input"])
for item in extra_out.attn_list:
p_choose = item["p_choose"]
self.assertTrue(p_choose.le(1.0).all())
self.assertTrue(p_choose.ge(0.0).all())
def test_expected_alignment(self):
for longer_src in [True, False]:
sample = make_sample_with_padding(longer_src)
_, extra_out = self.model.forward(**sample["net_input"])
for item in extra_out.attn_list:
p_choose = item["p_choose"]
alpha_system = item["alpha"]
self.assertTrue(p_choose.size() == alpha_system.size())
bsz, num_head, tgt_len, src_len = alpha_system.size()
alpha_system = alpha_system.view(-1, tgt_len, src_len)
p_choose = p_choose.view(-1, tgt_len, src_len)
alpha_real = expected_alignment_formula(
p_choose,
self.model.decoder.layers[0].encoder_attn.mass_preservation,
sample["net_input"]["src_tokens"].eq(PAD_INDEX)
)
self.assertTrue(
torch.abs(alpha_system - alpha_real).le(5e-5).all(),
)
class HardMonotonicAttentionTestCase(
unittest.TestCase,
MonotonicAttentionTestAbstractClass
):
def setUp(self):
self.model = build_transformer_monotonic_attention(
**generate_config({"simul_type": "hard_aligned"})
)
class InfiniteLookbackTestCase(
unittest.TestCase,
MonotonicAttentionTestAbstractClass
):
def setUp(self):
self.model = build_transformer_monotonic_attention(
**generate_config(
{
"simul_type": "infinite_lookback"
}
)
)
self.model.train()
def test_fp16_for_long_input(self):
sample = {
"net_input": {
"src_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
"prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
"src_lengths": torch.LongTensor([1000]).cuda(),
},
"target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
}
self.model.cuda().half()
_, extra_out = self.model.forward(**sample["net_input"])
for item in extra_out.attn_list:
for key in ["p_choose", "alpha", "beta", "soft_energy"]:
self.assertFalse(torch.isnan(item[key]).any())
def test_expected_attention(self):
for longer_src in [True, False]:
sample = make_sample_with_padding(longer_src)
_, extra_out = self.model.forward(**sample["net_input"])
for item in extra_out.attn_list:
p_choose = item["p_choose"]
alpha_system = item["alpha"]
beta_system = item["beta"]
soft_energy_system = item["soft_energy"]
self.assertTrue(beta_system.size() == alpha_system.size())
self.assertTrue(p_choose.size() == alpha_system.size())
bsz, num_head, tgt_len, src_len = alpha_system.size()
alpha_system = alpha_system.view(-1, tgt_len, src_len)
beta_system = beta_system.view(-1, tgt_len, src_len)
p_choose = p_choose.view(-1, tgt_len, src_len)
soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len)
alpha_real = expected_alignment_formula(
p_choose,
self.model.decoder.layers[0].encoder_attn.mass_preservation,
sample["net_input"]["src_tokens"].eq(PAD_INDEX)
)
beta_real = expected_soft_attention_formula(
alpha_real,
soft_energy_system,
sample["net_input"]["src_tokens"].eq(PAD_INDEX),
chunksize=getattr(
self.model.decoder.layers[0].encoder_attn,
"chunk_size",
int(1e10)
)
)
self.assertTrue(
torch.abs(beta_system - beta_real).le(1e-5).all(),
)
class ChunkwiseTestCase(
InfiniteLookbackTestCase
):
def setUp(self):
self.model = build_transformer_monotonic_attention(
**generate_config(
{
"simul_type": "chunkwise",
"mocha_chunk_size": 3
}
)
)
class WaitkTestCase(InfiniteLookbackTestCase):
def setUp(self):
self.model = build_transformer_monotonic_attention(
**generate_config(
{
"simul_type": "waitk",
"waitk_lagging": 3,
}
)
)
def check_waitk(self, p_choose, lagging, padding_mask):
bsz, tgt_len, src_len = p_choose.size()
for bsz_i in range(bsz):
for i in range(tgt_len):
for j in range(src_len):
if not padding_mask[bsz_i, j]:
if j - i == lagging - 1:
self.assertTrue(p_choose[bsz_i, i, j] == 1)
else:
self.assertTrue(p_choose[bsz_i, i, j] == 0)
def test_waitk_p_choose(self):
for longer_src in [True, False]:
for k in [1, 3, 10, 20, 100]:
sample = make_sample_with_padding(longer_src)
model = build_transformer_monotonic_attention(
**generate_config(
{
"simul_type": "waitk",
"waitk_lagging": k,
}
)
)
model.train()
_, extra_out = model.forward(**sample["net_input"])
for item in extra_out.attn_list:
p_choose = item["p_choose"]
bsz, num_heads, tgt_len, src_len = p_choose.size()
padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
padding_mask = (
padding_mask
.unsqueeze(1)
.expand([bsz, num_heads, src_len])
.contiguous()
.view(-1, src_len)
)
p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len)
self.check_waitk(p_choose, k, padding_mask)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("examples.simultaneous_translation.utils." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def prob_check(tensor, eps=1e-10):
assert not torch.isnan(tensor).any(), (
"Nan in a probability tensor."
)
# Add the eps here to prevent errors introduced by precision
assert tensor.le(1.0 + eps).all() and tensor.ge(0.0 - eps).all(), (
"Incorrect values in a probability tensor"
", 0.0 <= tensor <= 1.0"
)
def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
Implements exclusive cumprod.
PyTorch provides cumprod, but no exclusive mode.
cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
exclusive means
exclusive_cumprod(x) = [1, x1, x1x2, ..., prod_{i=1}^{n-1} x_i]
"""
tensor_size = list(tensor.size())
tensor_size[dim] = 1
return_tensor = safe_cumprod(
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
dim=dim,
eps=eps,
)
if dim == 0:
return return_tensor[:-1]
elif dim == 1:
return return_tensor[:, :-1]
elif dim == 2:
return return_tensor[:, :, :-1]
else:
raise RuntimeError(
"Cumprod on dimension 3 and more is not implemented"
)
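# Minimal usage sketch (illustrative, not part of the original module); the demo
# function below is hypothetical and is never called at import time.
def _demo_exclusive_cumprod():
    x = torch.tensor([0.5, 0.5, 0.5])
    # Approximately tensor([1.0000, 0.5000, 0.2500]) up to the eps inside
    # safe_cumprod: the running product shifted right by one, with a leading 1.
    return exclusive_cumprod(x, dim=0)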
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
An implementation of cumprod that avoids precision issues.
cumprod(x)
= [x1, x1x2, x1x2x3, ....]
= [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
= exp(cumsum(log(x)))
"""
if (tensor + eps < 0).any().item():
raise RuntimeError(
"Safe cumprod can only take non-negative tensors as input."
"Consider use torch.cumprod if you want to calculate negative values."
)
log_tensor = torch.log(tensor + eps)
cumsum_log_tensor = torch.cumsum(log_tensor, dim)
exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor)
return exp_cumsum_log_tensor
def moving_sum(x, start_idx: int, end_idx: int):
"""
From MONOTONIC CHUNKWISE ATTENTION
https://arxiv.org/pdf/1712.05382.pdf
Equation (18)
x = [x_1, x_2, ..., x_N]
MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n−(start_idx−1)}^{n+end_idx-1} x_m
for n in {1, 2, 3, ..., N}
x : batch_size, tgt_len, src_len (the moving sum runs over the last,
src_len, dimension; the example below shows one src_len x batch_size slice)
start_idx : the window covers start_idx - 1 positions before n (plus n itself)
end_idx : the window covers end_idx - 1 positions after n
Example
src_len = 5
batch_size = 3
x =
[[ 0, 5, 10],
[ 1, 6, 11],
[ 2, 7, 12],
[ 3, 8, 13],
[ 4, 9, 14]]
MovingSum(x, 3, 1) =
[[ 0, 5, 10],
[ 1, 11, 21],
[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39]]
MovingSum(x, 1, 3) =
[[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39],
[ 7, 17, 27],
[ 4, 9, 14]]
"""
# TODO: Make dimension configurable
assert start_idx > 0 and end_idx > 0
batch_size, tgt_len, src_len = x.size()
x = x.view(-1, src_len).unsqueeze(1)
# batch_size, 1, src_len
moving_sum_weight = torch.ones([1, 1, end_idx + start_idx - 1]).type_as(x)
moving_sum = torch.nn.functional.conv1d(
x, moving_sum_weight, padding=start_idx + end_idx - 1
).squeeze(1)
moving_sum = moving_sum[:, end_idx:-start_idx]
assert src_len == moving_sum.size(1)
assert batch_size * tgt_len == moving_sum.size(0)
moving_sum = moving_sum.view(batch_size, tgt_len, src_len)
return moving_sum
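# Minimal usage sketch (illustrative, not part of the original module). The
# implementation expects x of shape (batch_size, tgt_len, src_len) and sums over
# the last dimension; this reproduces the first column of the docstring example.
def _demo_moving_sum():
    x = torch.arange(5, dtype=torch.float).view(1, 1, 5)  # [[[0, 1, 2, 3, 4]]]
    # Expected: tensor([[[0., 1., 3., 6., 9.]]]), i.e. MovingSum(x, 3, 1) above
    return moving_sum(x, start_idx=3, end_idx=1)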