Commit 72f5785f authored by huaerkl's avatar huaerkl
Browse files

v1.0

parents
Pipeline #505 canceled with stages
# Deep Transformers with Latent Depth (Li et al., 2020)
[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).
## Introduction
We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
## Training a multilingual model with latent depth
Below is an example of training with latent depth in decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which could be generated by [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) the author provided.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
fairseq-train ${databin_dir} \
--user-dir examples/latent_depth/latent_depth_src \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation_latent_depth \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=legacy_ddp \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
## Inference command
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>
fairseq-generate ${databin_dir} \
--path ${model_path} \
--task multilingual_translation_latent_depth \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```
## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import multilingual_translation_latent_depth # noqa
from .loss import latent_depth # noqa
from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.nn.modules.loss import _Loss
class LatentLayersKLLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
def forward(self, layer_samples, lang_idx, update_num, sample_size):
prior = self.args.prior
samples = layer_samples[lang_idx]
eps = 1e-7
if prior == "uniform":
# uniform prior
kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
elif prior == "agged_posterior":
# aggregated posterior
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
agged_q = torch.sum(y_t, dim=0)
row_norm = agged_q.sum(-1)
normed_agg_q = agged_q / row_norm
kl_loss = (
samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
).sum(-1)
else:
raise NotImplementedError("The specified prior is not implemented.")
# normalized by number of layers
kl_loss /= layer_samples[0].size()[0]
kl_weight = min(
self.args.sparsity_weight,
(update_num - self.args.soft_update)
* self.args.sparsity_weight
/ self.args.anneal_updates,
)
kl_loss *= kl_weight * sample_size
return kl_loss
class LatentLayersSparsityLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
def is_valid(self, update_num):
if self.args.target_layers <= 0:
return False
return update_num > (self.args.soft_update + self.args.anneal_updates)
def forward(self, layer_samples_list, update_num, sample_size):
batch_loss = 0
share_loss = 0
global_sparsity_loss = 0
layer_samples = torch.stack(layer_samples_list, dim=0)
if (
self.args.target_layers > 0 or self.args.share_weight > 0
) and update_num > (self.args.soft_update + self.args.anneal_updates):
# anneal sparsity weight
if update_num < (self.args.anneal_updates + self.args.soft_update):
weight_anneal = 0
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
weight_anneal = (
(update_num - self.args.soft_update - self.args.anneal_updates)
* self.args.share_weight
/ self.args.anneal_updates
)
else:
weight_anneal = 1
# compute ratio among languages
layer_utilization = torch.sum(layer_samples, dim=0)
layer_utilization /= layer_samples.size()[0]
if self.args.share_weight > 0:
# encouraging sharing across languages
share_loss = sum(
-1.0 * v * math.log(v) for v in layer_utilization if v > 0
)
batch_loss += (
weight_anneal * self.args.share_weight * sample_size * share_loss
)
if self.args.target_layers > 0:
# computed expected number of layers selected
expeted_layers = sum(layer_utilization)
# compute l2 loss wrt target number of layers
global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2
batch_loss += (
weight_anneal
* self.args.share_weight
* sample_size
* global_sparsity_loss
)
return batch_loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
base_architecture,
)
from fairseq.utils import safe_hasattr
from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
"""A variant of standard multilingual Transformer models which encoder and/or
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
(https://arxiv.org/abs/2009.13102).
"""
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
MultilingualTransformerModel.add_args(parser)
parser.add_argument(
'--soft-select',
action='store_true',
help='use soft samples in training an inference',
)
parser.add_argument(
'--sampling-tau',
type=float,
default=5.,
help='sampling temperature',
)
@classmethod
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
if is_encoder:
if safe_hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
return LatentTransformerEncoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerEncoder(args, lang_dict, embed_tokens)
else:
if safe_hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
return LatentTransformerDecoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerDecoder(args, lang_dict, embed_tokens)
@register_model_architecture(
"latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.decoder_layers = getattr(args, "decoder_layers", 24)
args.share_encoders = getattr(args, "share_encoders", True)
args.share_decoders = getattr(args, "share_decoders", True)
args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor
from ..modules.latent_layers import LayerSelect
class LatentTransformerEncoder(TransformerEncoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerEncoder.
"""
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
self.num_logits = num_logits
self.num_layers = args.encoder_layers
super().__init__(args, dictionary, embed_tokens)
self.layer_select = LayerSelect(
num_layers=self.num_layers,
num_logits=self.num_logits,
soft_select=getattr(args, "soft_select", False),
sampling_tau=getattr(args, "sampling_tau", 5.),
)
self.lang_idx = None
self.layers = nn.ModuleList(
[self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_encoder_layer(self, args, idx=None):
return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)
def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
self.layer_select.sample(self.lang_idx)
return super().forward(src_tokens, src_lengths, return_all_hiddens)
class LatentTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer with each (non_residual) block weighted by samples of Bernouli
or Gumbel Signmoid samples.
Args:
args (argparse.Namespace): parsed command-line arguments from standard
TransformerEncoderLayer.
idx (int): layer index (used to retrieve samples).
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
"""
def __init__(self, args, idx, layer_select=None):
super().__init__(args)
self.idx = idx
self.layer_select = layer_select
def residual_connection(self, x, residual):
return residual + x * self.layer_select(self.idx)
class LatentTransformerDecoder(TransformerDecoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerDecoder.
"""
def __init__(
self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
):
self.num_logits = num_logits
self.num_layers = args.decoder_layers
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
)
self.layer_select = LayerSelect(
num_layers=self.num_layers,
num_logits=self.num_logits,
soft_select=getattr(args, "soft_select", False),
sampling_tau=getattr(args, "sampling_tau", 5.),
)
self.lang_idx = None
self.layers = nn.ModuleList(
[
self._build_decoder_layer(args, no_encoder_attn, idx)
for idx in range(args.decoder_layers)
]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
return LatentTransformerDecoderLayer(
args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
)
def forward(
self,
prev_output_tokens,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
self.layer_select.sample(self.lang_idx)
return super().forward(
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
features_only=features_only,
alignment_layer=alignment_layer,
src_lengths=src_lengths,
return_all_hiddens=return_all_hiddens,
)
class LatentTransformerDecoderLayer(TransformerDecoderLayer):
"""Decoder layer with each (non_residual) block weighted by samples of Bernouli
or Gumbel Signmoid samples.
Args:
args (argparse.Namespace): parsed command-line arguments from standard
TransformerDecoderLayer.
idx (int): layer index (used to retrieve samples).
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
args,
idx,
layer_select=None,
no_encoder_attn=False,
add_bias_kv=False,
add_zero_attn=False,
):
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
self.idx = idx
self.layer_select = layer_select
def residual_connection(self, x, residual):
return residual + x * self.layer_select(self.idx)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
class LayerSelect(nn.Module):
"""Compute samples (from a Gumbel-Sigmoid distribution) which is used as
either (soft) weighting or (hard) selection of residual connection.
https://arxiv.org/abs/2009.13102
"""
def __init__(self, num_layers, num_logits, soft_select=False, sampling_tau=5.):
super(LayerSelect, self).__init__()
self.layer_logits = torch.nn.Parameter(
torch.Tensor(num_logits, num_layers),
requires_grad=True,
)
self.hard_select = not soft_select
self.tau = sampling_tau
self.detach_grad = False
self.layer_samples = [None] * num_logits
def sample(self, logit_idx):
"""To leverage the efficiency of distributed training, samples for all
layers are computed at once for each logit_idx. Logits are parameters
learnt independent of each other.
Args:
logit_idx: The index of logit parameters used for sampling.
"""
assert logit_idx is not None
self.samples = self._gumbel_sigmoid(
self.layer_logits[logit_idx, :].detach()
if self.detach_grad
else self.layer_logits[logit_idx, :],
dim=-1,
tau=self.tau,
hard=self.hard_select,
)
self.layer_samples[logit_idx] = self.samples
def forward(self, i):
sample = self.samples[i]
return sample
def _gumbel_sigmoid(
self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
):
# ~Gumbel(0,1)
gumbels1 = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_()
.log()
)
gumbels2 = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_()
.log()
)
# Difference of two gumbels because we apply a sigmoid
gumbels1 = (logits + gumbels1 - gumbels2) / tau
y_soft = gumbels1.sigmoid()
if hard:
# Straight through.
y_hard = torch.zeros_like(
logits, memory_format=torch.legacy_contiguous_format
).masked_fill(y_soft > threshold, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
# Reparametrization trick.
ret = y_soft
return ret
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from fairseq.utils import safe_hasattr
from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss
@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
"""A task for multiple translation with latent depth.
See `"Deep Transformer with Latent Depth"
(Li et al., 2020) <https://arxiv.org/pdf/2009.13102.pdf>`_.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
MultilingualTranslationTask.add_args(parser)
parser.add_argument('--encoder-latent-layer', action='store_true', help='latent layer selection in encoder')
parser.add_argument('--decoder-latent-layer', action='store_true', help='latent layer selection in decoder')
parser.add_argument('--target-layers', default=-1, type=int,
help='number of effective layers to learn; -1 means no constraint')
parser.add_argument('--sparsity-weight', default=0.0, type=float,
help='weight for sparsity loss')
parser.add_argument('--share-weight', default=0.0, type=float,
help='weight for sharing loss')
parser.add_argument('--soft-update', default=1, type=int,
help='number of updates with soft sampling')
parser.add_argument('--anneal-updates', default=1, type=int,
help='number of updates to anneal the KL loss weight')
parser.add_argument('--prior', default="uniform", type=str,
help='prior used for computing KL loss')
# fmt: on
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
self.src_langs, self.tgt_langs = zip(
*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
)
if self.training and self.encoder_latent_layer:
assert self.args.share_encoders
if self.training and self.decoder_latent_layer:
assert self.args.share_decoders
if training or self.encoder_latent_layer or self.decoder_latent_layer:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
self.eval_lang_pairs = self.lang_pairs
self.model_lang_pairs = self.lang_pairs
if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
self.kl_loss = LatentLayersKLLoss(self.args)
self.sparsity_loss = LatentLayersSparsityLoss(self.args)
def _per_lang_pair_train_loss(
self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
):
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
model.models[lang_pair].encoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
model.models[lang_pair].decoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
if self.encoder_latent_layer:
none_samples = sum(
1 if x is None else 0
for x in model.models[lang_pair].encoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].encoder.layer_select.layer_samples,
src_lang_idx,
update_num,
sample_size,
)
if self.decoder_latent_layer:
none_samples = sum(
1 if x is None else 0
for x in model.models[lang_pair].decoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].decoder.layer_select.layer_samples,
tgt_lang_idx,
update_num,
sample_size,
)
if ignore_grad:
loss *= 0
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
# need to retain the graph if sparsity loss needs to be added
loss.backward(retain_graph=True)
else:
optimizer.backward(loss)
return loss, sample_size, logging_output
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
agg_loss, agg_sample_size, agg_logging_output = super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad
)
# compute auxiliary loss from layere sparsity, based on all samples from all languages
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
sparsity_loss = 0
if self.encoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(
iter(model.models.values())
).encoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if self.decoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(
iter(model.models.values())
).decoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if sparsity_loss > 0:
optimizer.backward(sparsity_loss)
return agg_loss, agg_sample_size, agg_logging_output
def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
return loss, sample_size, logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
if self.encoder_latent_layer or self.decoder_latent_layer:
for model in models:
if self.encoder_latent_layer:
assert model.encoder.layer_select is not None
src_lang_idx = self.src_lang_idx_dict[self.args.source_lang]
model.encoder.set_lang_idx(src_lang_idx)
if self.decoder_latent_layer:
assert model.decoder.layer_select is not None
tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
model.decoder.set_lang_idx(tgt_lang_idx)
return super().inference_step(
generator, models, sample, prefix_tokens, constraints
)
@property
def encoder_latent_layer(self):
return (
safe_hasattr(self.args, "encoder_latent_layer")
and self.args.encoder_latent_layer
)
@property
def decoder_latent_layer(self):
return (
safe_hasattr(self.args, "decoder_latent_layer")
and self.args.decoder_latent_layer
)
@property
def src_lang_idx_dict(self):
return {lang: lang_idx for lang_idx, lang in enumerate(self.src_langs)}
@property
def tgt_lang_idx_dict(self):
return {lang: lang_idx for lang_idx, lang in enumerate(self.tgt_langs)}
# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)
This page contains information for how to train models with LayerDrop, based on this [paper](https://arxiv.org/abs/1909.11556).
## Citation:
If you found this technique useful, please cite our paper:
```bibtex
@article{fan2019reducing,
title={Reducing Transformer Depth on Demand with Structured Dropout},
author={Fan, Angela and Grave, Edouard and Joulin, Armand},
journal={arXiv preprint arXiv:1909.11556},
year={2019}
}
```
## Pre-trained models
Model | Description | Download
---|---|---
`layerdrop_wmt_en_de_12_6` | Transformer + LayerDrop 0.2 trained on WMT16 en-de with 12 encoder and 6 decoder layers | [layerdrop_wmt_en_de_12_6.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/layerdrop_wmt_en_de_12_6.tar.gz)
`roberta_layerdrop.base` | RoBERTa Base + LayerDrop 0.2 | [roberta_layerdrop.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.base.qnli.tar.gz)
`roberta_layerdrop.large` | RoBERTa Large + LayerDrop 0.2 | [roberta_layerdrop.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.tar.gz)
`roberta_layerdrop.large.mnli` | `roberta_layerdrop.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | [roberta_layerdrop.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.mnli.tar.gz)
`roberta_layerdrop.large.qnli` | `roberta_layerdrop.large` finetuned on [QNLI](https://arxiv.org/abs/1804.07461) | [roberta_layerdrop.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.qnli.tar.gz)
Evaluate performance of these pre-trained models:
```bash
# Example for Machine Translation
fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \
--beam 8 --lenpen 0.4 \
--batch-size 64 \
--remove-bpe \
--gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
# prints BLEU4 = 30.17
```
```python
# Example for RoBERTa + LayerDrop finetuned on MNLI:
from fairseq.models.roberta import RobertaModel
roberta_layerdrop = RobertaModel.from_pretrained(
'/path/to/MNLI/model',
checkpoint_file='mnli_checkpoint.pt',
data_name_or_path='/path/to/MNLI/data/MNLI-bin'
)
label_map = {0: 'contradiction', 2: 'neutral', 1: 'entailment'}
ncorrect, nsamples = 0, 0
roberta_layerdrop.cuda()
roberta_layerdrop.eval()
with open('/path/to/MNLI/data/dev_matched.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
tokens = roberta_layerdrop.encode(sent1, sent2)
prediction = roberta_layerdrop.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_map[prediction]
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# prints | Accuracy: 0.9026999490575649
# Example for RoBERTa + LayerDrop finetuned on QNLI:
roberta = RobertaModel.from_pretrained(
'/path/to/QNLI/model',
checkpoint_file='qnli_checkpoint.pt',
data_name_or_path='/path/to/QNLI/data/QNLI-bin'
)
label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('/path/to/QNLI/data/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# prints | Accuracy: 0.9480139117700896
```
## Example usage
To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For Language Models that are decoder-only, you need only the decoder flag. For RoBERTa, an encoder, you need only the encoder flag. The encoder and decoder LayerDrop values can be set differently.
```
--encoder-layerdrop 0.2 --decoder-layerdrop 0.2
```
To prune a model that has been trained with LayerDrop, add the following flags followed by a comma separated list of which layers you would like to keep.
```
--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14
```
Setting these flags should print a message such as:
```
| Pruning model to specified layer configuration
```
You should also see a smaller number of parameters in the model, for example the 16-Layer Transformer Language Model prints:
```
num. model params: 246933504
```
while a model pruned to 8 Layers prints:
```
num. model params: 146163712
```
If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example would be for language modeling:
```bash
fairseq-eval-lm /path/to/wikitext-103 \
--path /path/to/model/checkpoint.pt \
--model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```
This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model.
## Reproduce Paper Results
Looking to reproduce the results in the paper?
1. For Translation on WMT16 en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/main/examples/scaling_nmt/README.md)
2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/roberta)
3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model)
## Tips
1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop will mean the model has too much regularization, so may not reach the best performance. Since LayerDrop adds regularization, you may achieve the best performance by slightly reducing the amount of standard dropout (for example, reduce by 0.1).
2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (such as removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2. Our experiments were conducted with low values of LayerDrop (such as 0.1 and 0.2), for reference.
3. When pruning layers at inference time, it is best to spread out the layers remaining so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer is good.
## FAQ
1. How did the sharing layers experiment work? In an appendix (https://openreview.net/pdf?id=SylO2yStDr) we added an experiment on Wikitext-103 language modeling that combined LayerDrop with Weight Sharing. We shared chunks of 2 layers such that every other layer had shared weights. For example, if our network has layers 1 through 6, then layer 1 and 2 are shared, layer 3 and 4 are shared, and layer 5 and 6 are shared.
2. LayerDrop hasn't been helping in my setting? During training time, LayerDrop can help regularize your network. This is most important if your network is already overfitting - if your network is underfitting, it is possible LayerDrop is adding too much regularization. We recommend using smaller values (such as 0.1 or 0.2) and also decreasing the quantity of standard dropout (for example, reduce by 0.1).
3. Can you train a model without LayerDrop and finetune with LayerDrop (e.g. for BERT)? In our experiments, we did not see great performance. Models such as RoBERTa have trained for a long time in the pre-training setting, so only finetuning with LayerDrop for a few epochs on a downstream task such as MNLI does not achieve the robustness required for successful pruning.
## Having an issue or have a question?
Please open an issue in this repository with the details of your question. Thanks!
# Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)
This example contains code to train Linformer models as described in our paper
[Linformer: Self-Attention with Linear Complexity](https://arxiv.org/abs/2006.04768).
## Training a new Linformer RoBERTa model
You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md),
updating your training command with `--user-dir examples/linformer/linformer_src --arch linformer_roberta_base`.
## Citation
If you use our work, please cite:
```bibtex
@article{wang2020linformer,
title={Linformer: Self-Attention with Linear Complexity},
author={Wang, Sinong and Li, Belinda and Khabsa, Madian and Fang, Han and Ma, Hao},
journal={arXiv preprint arXiv:2006.04768},
year={2020}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .models import linformer_roberta # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Linformer: Self-Attention with Linear Complexity
"""
import logging
import torch
from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
init_bert_params,
roberta_base_architecture,
roberta_large_architecture,
RobertaEncoder,
RobertaModel,
)
from fairseq.utils import safe_hasattr
from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder
logger = logging.getLogger(__name__)
@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
@staticmethod
def add_args(parser):
RobertaModel.add_args(parser)
# add args for Linformer
parser.add_argument(
"--compressed", type=int, help="compressed ratio of sequence length"
)
parser.add_argument(
"--shared-kv-compressed",
type=int,
help="share compressed matrix between k and v, in each layer",
)
parser.add_argument(
"--shared-layer-kv-compressed",
type=int,
help="share compressed matrix between k and v and across all layers",
)
parser.add_argument(
"--freeze-compress",
type=int,
help="freeze the parameters in compressed layer",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present
base_architecture(args)
if not safe_hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
encoder = LinformerEncoder(args, task.source_dictionary)
return cls(args, encoder)
class LinformerEncoder(RobertaEncoder):
"""Linformer encoder."""
def __init__(self, args, dictionary):
super().__init__(args, dictionary)
self.register_buffer("version", torch.tensor(2))
def build_encoder(self, args, dictionary, embed_tokens):
encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens)
encoder.apply(init_bert_params)
return encoder
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
prefix = name + "." if name != "" else ""
# some old checkpoints had weight sharing implemented incorrectly
# (note: this was correct in the original paper code)
if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2:
state_dict[f"{prefix}version"] = torch.tensor(1)
# check if input embeddings and output embeddings were tied
if not torch.allclose(
state_dict[f"{prefix}sentence_encoder.embed_tokens.weight"],
state_dict[f"{prefix}lm_head.weight"],
):
# they weren't tied, re-init the LM head without weight sharing
self.lm_head = self.build_lm_head(
embed_dim=self.args.encoder_embed_dim,
output_dim=len(self.dictionary),
activation_fn=self.args.activation_fn,
weight=None, # don't share weights
)
@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
args.compressed = getattr(args, "compressed", 4)
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
args.freeze_compress = getattr(args, "freeze_compress", 0)
roberta_base_architecture(args)
@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
roberta_large_architecture(args)
base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch.nn as nn
from fairseq.models.transformer import TransformerEncoder
from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer
class LinformerTransformerEncoder(TransformerEncoder):
"""
Implementation for a Bi-directional Linformer based Sentence Encoder used
in BERT/XLM style pre-trained models.
This first computes the token embedding using the token embedding matrix,
position embeddings (if specified) and segment embeddings
(if specified). After applying the specified number of
LinformerEncoderLayers, it outputs all the internal states of the
encoder as well as the final representation associated with the first
token (usually CLS token).
Input:
- tokens: B x T matrix representing sentences
- segment_labels: B x T matrix representing segment label for tokens
Output:
- a tuple of the following:
- a list of internal model states used to compute the
predictions where each tensor has shape T x B x C
- sentence representation associated with first input token
in format B x C.
"""
def __init__(self, args, dictionary, embed_tokens):
self.compress_layer = None
super().__init__(args, dictionary, embed_tokens)
def build_encoder_layer(self, args):
if self.args.shared_layer_kv_compressed == 1 and self.compress_layer is None:
compress_layer = nn.Linear(
self.args.max_positions,
self.args.max_positions // self.args.compressed,
)
# intialize parameters for compressed layer
nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
if self.args.freeze_compress == 1:
compress_layer.weight.requires_grad = False
self.compress_layer = compress_layer
return LinformerTransformerEncoderLayer(args, self.compress_layer)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from fairseq.modules import TransformerEncoderLayer
from .multihead_linear_attention import MultiheadLinearAttention
class LinformerTransformerEncoderLayer(TransformerEncoderLayer):
"""
Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(self, args, shared_compress_layer):
# wrap in a list so it's not automatically registered by PyTorch
self.shared_compress_layer = [shared_compress_layer]
super().__init__(args)
self.register_buffer("version", torch.tensor(2))
def build_self_attention(self, embed_dim, args):
return MultiheadLinearAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.dropout,
self_attention=True,
q_noise=args.quant_noise_pq,
qn_block_size=args.quant_noise_pq_block_size,
compressed=args.compressed,
max_seq_len=args.max_positions,
shared_kv_compressed=args.shared_kv_compressed,
shared_compress_layer=self.shared_compress_layer[0],
freeze_compress=args.freeze_compress,
)
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
prefix = name + "." if name != "" else ""
# some old checkpoints had weight sharing implemented incorrectly
# (note: this was correct in the original paper code)
if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2:
state_dict[f"{prefix}version"] = torch.tensor(1)
# check compression layer sharing
if f"{prefix}shared_compress_layer.weight" in state_dict:
# reinitialize block without sharing compression layer to match
# old behavior
self.shared_compress_layer = [
torch.nn.Linear(
self.shared_compress_layer[0].weight.size(1),
self.shared_compress_layer[0].weight.size(0),
)
]
self.self_attn = self.build_self_attention(self.embed_dim, self.args)
# delete shared_compress_layer, since it's already copied to
# self_attn.compress_k.weight
del state_dict[f"{prefix}shared_compress_layer.weight"]
if f"{prefix}shared_compress_layer.bias" in state_dict:
del state_dict[f"{prefix}shared_compress_layer.bias"]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter
@with_incremental_state
class MultiheadLinearAttention(nn.Module):
"""Multi-headed linformer attention.
Projects the key and values down to the compressed dimension, before computing self-attention.
See "Linformer: Self-Attention with Linear Complexity" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
compressed=1,
max_seq_len=256,
shared_kv_compressed=0,
shared_compress_layer=None,
freeze_compress=0,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
# used for compress sequence to subsequence
if shared_compress_layer is None:
self.compress_seq_len = max_seq_len // compressed
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
if shared_kv_compressed == 0:
self.compress_v = nn.Linear(
max_seq_len, self.compress_seq_len, bias=False
)
self.layerwise_sharing = False
else:
self.compress_k = shared_compress_layer
if shared_kv_compressed == 0:
self.compress_v = shared_compress_layer
self.layerwise_sharing = True
self.shared_kv_compressed = shared_kv_compressed
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.reset_parameters()
if freeze_compress == 1:
self.compress_k.weight.requires_grad = False
if shared_kv_compressed == 0:
self.compress_v.weight.requires_grad = False
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
if (
not self.layerwise_sharing
): # otherwise, we already initialize the parameters
nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
if self.shared_kv_compressed == 0:
nn.init.xavier_uniform_(
self.compress_v.weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
if (
not self.layerwise_sharing
): # otherwise, we already initialize the parameters
nn.init.xavier_uniform_(self.compress_k.weight)
if self.shared_kv_compressed == 0:
nn.init.xavier_uniform_(self.compress_v.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k_input = query.permute(1, 2, 0).contiguous() # B * C * T
k_input = (
F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
.permute(2, 0, 1)
.contiguous()
)
k = self.k_proj(k_input)
v_input = query.permute(1, 2, 0).contiguous() # B * C * T
if self.shared_kv_compressed == 0:
v_input = (
F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
.permute(2, 0, 1)
.contiguous()
)
if self.shared_kv_compressed == 1: # use shared kv compressed linear layer
v_input = (
F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
.permute(2, 0, 1)
.contiguous()
)
v = self.v_proj(v_input)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadLinearAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = MultiheadLinearAttention.apply_sparse_mask(
attn_weights, tgt_len, src_len, bsz
)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(
attn_weights,
p=self.dropout,
training=self.training,
)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
elif key_padding_mask is not None:
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
@torch.jit.export
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
if self.encoder_decoder_attention and input_buffer_k.size(
0
) == new_order.size(0):
break
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return incremental_state
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
items_to_add = {}
keys_to_remove = []
for k in state_dict.keys():
if k.endswith(prefix + "in_proj_weight"):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
keys_to_remove.append(k)
k_bias = prefix + "in_proj_bias"
if k_bias in state_dict.keys():
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
keys_to_remove.append(prefix + "in_proj_bias")
for k in keys_to_remove:
del state_dict[k]
for key, value in items_to_add.items():
state_dict[key] = value
# Beyond English-Centric Multilingual Machine Translation
## Introduction
In this work, we create a true Many-to-Many multilingual translation model that can translate directly between any pair of 100 languages. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly translating between non-English directions while performing competitively with the best single systems of WMT.
If you are new to using fairseq, read the following walkthrough. Otherwise, skip to the sections below.
0. **Generation Data**
To download the generation data, follow the below commands. Note that all datasets need to be detokenized *before* applying SPM in the data preprocessing step. If you use these evaluation datasets, please cite their associated papers.
```bash
# WMT - use sacrebleu, example here:
sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr
sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en
# WAT
wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip
unzip wat2020.my-en.zip
# FLORES
# download from: https://github.com/facebookresearch/flores
# TED - need to detokenize with Moses!
# from: https://github.com/neulab/word-embeddings-for-nmt
wget http://phontron.com/data/ted_talks.tar.gz
# Autshumato
# request to download: https://repo.sadilar.org/handle/20.500.12185/397
# Tatoeba Challenge
# available here: https://github.com/Helsinki-NLP/Tatoeba-Challenge
```
1. **Training Data**
To produce the training data, we use a combination of [CCMatrix](https://arxiv.org/abs/1911.04944) and [CCAligned](https://arxiv.org/abs/1911.06154). Check out the instructions [here](https://github.com/facebookresearch/LASER/tree/master/tasks/CCMatrix) to download the raw data.
2. **Preprocess Data**
After downloading raw data, you will need to postprocess the data, then apply SPM, then binarize. Note that it is very important you run the postprocessing script, because this removes any instance of the evaluation data in the mined training data.
```bash
# preprocess data
# remove sentences with more than 50% punctuation
python /path/to/fairseq/examples/m2m_100/process_data/remove_too_much_punc.py
# deduplicate training data
paste /path/to/datadir/train.$src /path/to/datadir/train.$tgt | awk '!x[$0]++' > /path/to/datadir/train.dedup
echo "keeping $(wc -l /path/to/datadir/train.dedup) bitext out of $(wc -l /path/to/datadir/train.$src)"
cut -f1 /path/to/datadir/train.dedup > /path/to/datadir/train.$src
cut -f2 /path/to/datadir/train.dedup > /path/to/datadir/train.$tgt
# remove all instances of evaluation data from the training data
python /path/to/fairseq/examples/m2m_100/process_data/dedup_data.py
# frequency cleaning
wget https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz
tar -xvzf histograms.tar.gz
python /path/to/fairseq/examples/m2m_100/process_data/clean_histogram.py --src $src --tgt $tgt --src-file /path/to/source/file --tgt-file /path/to/output/file --src-output-file source_output.$src --tgt-output-file target_output.$tgt --histograms /path/to/histograms
# apply SPM
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
python /path/to/fairseq/scripts/spm_encode.py \
--model spm.128k.model \
--output_format=piece \
--inputs=/path/to/input/file/here \
--outputs=/path/to/output/file/here
# length ratio cleaning
perl mosesdecoder/scripts/training/clean-corpus-n.perl --ratio 3 /path/to/training/data/train.spm.$src-$tgt $src $tgt /path/to/output/directory/train.spm.$src-$tgt 1 250
# binarize data
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
fairseq-preprocess \
--source-lang $src --target-lang $tgt \
--testpref spm.$src.$tgt \
--thresholdsrc 0 --thresholdtgt 0 \
--destdir data_bin \
--srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
```
3. **Training Scripts**
To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/main/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale).
4. **Generation**
To generate from our models, follow the the commands in the generation section below.
If you use any of the resources listed here, please cite:
```bibtex
@article{fan2020beyond,
title={Beyond English-Centric Multilingual Machine Translation},
author={Fan, Angela and Bhosale, Shruti and Schwenk, Holger and Ma, Zhiyi and El-Kishky, Ahmed and Goyal, Siddharth and Baines, Mandeep and Celebi, Onur and Wenzek, Guillaume and Chaudhary, Vishrav and Goyal, Naman and Birch, Tom and Liptchinsky, Vitaliy and Edunov, Sergey and Grave, Edouard and Auli, Michael and Joulin, Armand},
journal={arXiv preprint},
year={2020}
}
@article{schwenk2019ccmatrix,
title={Ccmatrix: Mining billions of high-quality parallel sentences on the web},
author={Schwenk, Holger and Wenzek, Guillaume and Edunov, Sergey and Grave, Edouard and Joulin, Armand},
journal={arXiv preprint arXiv:1911.04944},
year={2019}
}
@article{el2019massive,
title={A Massive Collection of Cross-Lingual Web-Document Pairs},
author={El-Kishky, Ahmed and Chaudhary, Vishrav and Guzman, Francisco and Koehn, Philipp},
journal={arXiv preprint arXiv:1911.06154},
year={2019}
}
```
## Trained Models
### 418M and 1.2B Model
We include the last checkpoint for both of these models.
```bash
wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs_small_models.txt
# 418M parameter model
wget https://dl.fbaipublicfiles.com/m2m_100/418M_last_checkpoint.pt
# 1.2B parameter model
wget https://dl.fbaipublicfiles.com/m2m_100/1.2B_last_checkpoint.pt
# Generation:
fairseq-generate $binarized_data_path --batch-size 32 --path $path_to_model --fixed-dictionary model_dict.128k.txt -s en -t fr --remove-bpe 'sentencepiece' --beam 5 --task translation_multi_simple_epoch --lang-pairs language_pairs_small_models.txt --decoder-langtok --encoder-langtok src --gen-subset test > gen_out
```
### 12B Model
12B parameter model trained on many-to-many training data for 100 languages. We include the last checkpoint, average of last 5 checkpoints, average of last 10 checkpoints. There isn't a universally best choice out of these three, but all three versions are pretty close in accuracy. You can either sweep over the 3 checkpoints on a dev test and use the best performing checkpoint for final testing. Or the last checkpoint can be a good default choice.
**Model Download Links**
Configuration | 2 32GB GPUs | 4 16GB GPUs | 6 12GB GPUs | 8 8GB GPUs
:--|:--|:--|:--|:--
Last Checkpoint | [12b_last_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_2_gpus.pt) | [12b_last_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_4_gpus.pt) | [12b_last_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_6_gpus.pt) | [12b_last_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_8_gpus.pt)
Average of last 5 checkpoints | [12b_avg5_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_2_gpus.pt) | [12b_avg5_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_4_gpus.pt) | [12b_avg5_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_6_gpus.pt) | [12b_avg5_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_8_gpus.pt)
Average of last 10 checkpoints | [12b_avg10_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_2_gpus.pt) | [12b_avg10_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_4_gpus.pt) | [12b_avg10_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_6_gpus.pt) | [12b_avg10_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_8_gpus.pt)
**Generation Arguments**
Configuration | 2 32GB GPUs | 4 16GB GPUs | 6 12GB GPUs | 8 8GB GPUs
:--|:--|:--|:--|:--
`--pipeline-encoder-balance` | `[26]` | `[1,15,10]` | `[1,9,9,7]` | `[1,6,6,6,7]`
`--pipeline-encoder-devices` | `[0]` | `[0,1,0]` | `[0,1,2,0]` | `[0,4,5,1,0]`
`--pipeline-decoder-balance` | `[3,22,1]` | `[3,11,11,1]` | `[3,7,7,8,1]` | `[1,6,6,6,6,1]`
`--pipeline-decoder-devices` | `[0,1,0]` | `[0,2,3,0]` | `[0,3,4,5,0]` | `[0,2,6,7,3,0]`
## SentencePiece Model
```bash
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
```
## Generation with M2M-100
### Encode using our SentencePiece Model
Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)
```bash
fairseq=/path/to/fairseq
cd $fairseq
sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
for lang in de fr ; do
python scripts/spm_encode.py \
--model spm.128k.model \
--output_format=piece \
--inputs=raw_input.de-fr.${lang} \
--outputs=spm.de-fr.${lang}
done
```
### Binarization
```bash
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
fairseq-preprocess \
--source-lang de --target-lang fr \
--testpref spm.de-fr \
--thresholdsrc 0 --thresholdtgt 0 \
--destdir data_bin \
--srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
```
### Generation for the 12B model
Note that generation can currently be run using 2 32GB / 4 16GB / 6 12GB / 8 8GB GPUs, and the corresponding model checkpoints and pipeline arguments can be found in the [12B Model Section](#12b-model).
Generation on CPUs will be added in the future.
```bash
wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt
wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_4_gpus.pt
fairseq-generate \
data_bin \
--batch-size 1 \
--path 12b_last_chk_4_gpus.pt \
--fixed-dictionary model_dict.128k.txt \
-s de -t fr \
--remove-bpe 'sentencepiece' \
--beam 5 \
--task translation_multi_simple_epoch \
--lang-pairs language_pairs.txt \
--decoder-langtok --encoder-langtok src \
--gen-subset test \
--fp16 \
--dataset-impl mmap \
--distributed-world-size 1 --distributed-no-spawn \
--pipeline-model-parallel \
--pipeline-chunks 1 \
--pipeline-encoder-balance '[1,15,10]' \
--pipeline-encoder-devices '[0,1,0]' \
--pipeline-decoder-balance '[3,11,11,1]' \
--pipeline-decoder-devices '[0,2,3,0]' > gen_out
```
## Evaluation with M2M-100
### Tokenization
Note: Refer to tokenizers/README.md for more details on tokenization.
```bash
cd ${fairseq}/examples/m2m_100
cat ${fairseq}/gen_out | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh fr > hyp
cat ${fairseq}/raw_input.de-fr.fr | sh tok.sh fr > ref
```
### BLEU
```bash
sacrebleu -tok 'none' ref < hyp
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment