# Adaptive Span
Adaptive Span is a novel self-attention mechanism that can learn its optimal
attention span. This allows us to significantly extend the maximum context size
used in Transformers while maintaining control over their memory footprint
and computational time. It uses the Truncated BPTT technique for training,
as in [transformerXL](https://github.com/pytorch/fairseq/blob/master/examples/truncated_bptt/README.md).
Adaptive Span was introduced in the paper
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
which achieved state-of-the-art language modeling results at the time of publication.
We reproduce their results in fairseq and keep most of the
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
You can also refer to their sweep file if any hyperparameter combination is unclear.
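For intuition, the span is learned by applying a soft mask to the attention
weights so that the cutoff point stays differentiable. Below is a minimal sketch
of that idea (not the fairseq module itself; see `adaptive_span_attention.py` in
this example directory for the real implementation):
```python
import torch

def soft_span_mask(attn, current_val, max_span, ramp_size):
    """Softly mask the last `max_span` attention weights in `attn`.

    `current_val` in [0, 1] is a learnable scalar; the mask ramps linearly from
    0 to 1 over `ramp_size` positions, so the span length receives gradients.
    """
    # template values run from 1 - max_span to 0; entries far from 0 are masked
    # out first, so a small current_val keeps only a short attention span
    template = torch.linspace(1 - max_span, 0, steps=max_span)
    mask = ((template + current_val * max_span) / ramp_size + 1).clamp(0, 1)
    return attn * mask
```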
##### 0. Setup
First you need to process the Enwik8 dataset. We use the pre-tokenized dataset
from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
Download the dataset, then run:
```bash
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
--validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
--destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```
##### 1. Train an Adaptive Span model on Enwik8
We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/adaptive_span \
--data ~/data/enwik8/data-bin/ \
--fp16 --fp16-no-flatten-grads --max-update 600000 \
--task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
--n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
--attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
--validate-interval-updates 1000 \
--lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
--lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
--seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
This should land around 1.05 bpc on validation and 1.03 bpc on test. You can lower
`--aux-loss-scaler` for better performance (a longer span). Adaptive Span gives a
~0.03 bpc improvement over the Transformer-XL baseline here.
If training on a single GPU, set `--update-freq=4` to accumulate gradients over
4 steps and simulate training on 4 GPUs (see the example below).
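For example, a single-GPU run could look like the following sketch: identical
hyperparameters to the 4-GPU command above, with only the device list and
`--update-freq` changed so the effective batch size stays at 64 sequences:
```bash
CUDA_VISIBLE_DEVICES=0 fairseq-train \
--user-dir examples/adaptive_span \
--data ~/data/enwik8/data-bin/ \
--fp16 --fp16-no-flatten-grads --max-update 600000 \
--task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
--n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
--attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
--validate-interval-updates 1000 \
--lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
--lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 4 \
--seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```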
You can also reproduce the Transformer-XL result on enwik8 using this code base.
It should land around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try it with:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/truncated_bptt \
~/data/enwik8/data-bin/ \
--task truncated_bptt_lm --fp16 --max-update 400000 \
--tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
--d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
--dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
--lr-scheduler cosine --warmup-updates 0 \
--lr 0.0 --lr 0.00025 --batch-size 15 \
--update-freq 1 --seed 2 --log-format json --log-interval 25
```
##### 2. Evaluate
For Adaptive Span:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
--user-dir examples/adaptive_span \
--task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```
For Transformer-XL evaluation:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
--user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
--tokens-per-sample 80 \
--model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
--gen-subset valid
```
*Note:* During training the model saw 512 tokens of context
(``--tokens-per-sample=512``). The Adaptive Span evaluation above also uses a
512-token context with batch size 8, matching the evaluation settings from [the
original paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the current directory
cur_dir = os.path.dirname(__file__)
for file in os.listdir(cur_dir):
path = os.path.join(cur_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
mod_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module(__name__ + "." + mod_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.optim import Adagrad
from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
@register_optimizer("adagrad_with_grad_clip")
class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
help='internal grad clip')
# fmt: on
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
"lr": self.args.lr[0],
"weight_decay": self.args.weight_decay,
"grad_clip": self.args.adagrad_clip,
}
@property
def supports_flat_params(self):
return False
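# _clip_grad implements the "clipping" by rescaling the per-step learning rate:
# when the gradient norm exceeds `group_grad_clip`, the applied update is shrunk
# as if the gradient had been clipped to that norm (the Adagrad accumulator
# below still accumulates the unclipped gradient).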
def _clip_grad(clr, grad, group_grad_clip):
if group_grad_clip > 0:
norm = grad.norm(2).item()
if norm > group_grad_clip:
clr *= group_grad_clip / (norm + 1e-10)
return clr
class AdagradWithGradClip(Adagrad):
"""Adagrad algorithm with custom gradient clipping"""
def __init__(
self,
params,
lr=1e-2,
lr_decay=0,
weight_decay=0,
initial_accumulator_value=0,
grad_clip=0,
):
Adagrad.__init__(
self,
params,
lr=lr,
lr_decay=lr_decay,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value,
)
self.defaults["grad_clip"] = grad_clip
self.param_groups[0].setdefault("grad_clip", grad_clip)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
state["step"] += 1
if group["weight_decay"] != 0:
if p.grad.data.is_sparse:
raise RuntimeError(
"weight_decay option is "
"not compatible with sparse "
"gradients"
)
grad = grad.add(group["weight_decay"], p.data)
clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
# clip
clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
if grad.is_sparse:
# the update is non-linear so indices must be unique
grad = grad.coalesce()
grad_indices = grad._indices()
grad_values = grad._values()
size = grad.size()
def make_sparse(values):
constructor = grad.new
if grad_indices.dim() == 0 or values.dim() == 0:
return constructor().resize_as_(grad)
return constructor(grad_indices, values, size)
state["sum"].add_(make_sparse(grad_values.pow(2)))
std = state["sum"]._sparse_mask(grad)
std_values = std._values().sqrt_().add_(1e-10)
p.data.add_(-clr, make_sparse(grad_values / std_values))
else:
state["sum"].addcmul_(1, grad, grad)
std = state["sum"].sqrt().add_(1e-10)
p.data.addcdiv_(-clr, grad, std)
return loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveMask(nn.Module):
"""Soft masking function for adaptive size.
It masks out the last K values of an input. The masking value
goes from 1 to 0 gradually, so K can be learned with
back-propagation.
Args:
max_size: maximum size (i.e. input dimension)
ramp_size: size of the ramp going from 0 to 1
init_val: initial size proportion not to be masked out
shape: learn multiple sizes independent of each other
"""
def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
nn.Module.__init__(self)
self._max_size = max_size
self._ramp_size = ramp_size
self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
self.register_buffer("mask_template", mask_template)
def forward(self, x):
mask = self.mask_template.float() + self.current_val.float() * self._max_size
mask = mask / self._ramp_size + 1
mask = mask.clamp(0, 1)
if x.size(-1) < self._max_size:
# the input could have been trimmed beforehand to save computation
mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
x = (x * mask).type_as(x)
return x
def get_current_max_size(self, include_ramp=True):
current_size = math.ceil(self.current_val.max().item() * self._max_size)
if include_ramp:
current_size += self._ramp_size
current_size = max(0, min(self._max_size, current_size))
return current_size
def get_current_avg_size(self, include_ramp=True):
current_size = math.ceil(
self.current_val.float().mean().item() * self._max_size
)
if include_ramp:
current_size += self._ramp_size
current_size = max(0, min(self._max_size, current_size))
return current_size
def clamp_param(self):
"""this needs to be called after each update"""
self.current_val.data.clamp_(0, 1)
class AdaptiveSpan(nn.Module):
"""Adaptive attention span for Transformers.
This module learns an attention span length from data for each
self-attention head.
Args:
attn_span: maximum attention span
adapt_span_ramp: length of the masking ramp
adapt_span_init: initial size ratio
n_head: number of attention heads
adapt_span_layer: learn a single span for the whole layer instead of one per head
"""
def __init__(
self,
attn_span,
adapt_span_ramp,
adapt_span_init,
n_head,
adapt_span_layer,
**kargs
):
nn.Module.__init__(self)
self._max_span = attn_span
self._n_head = n_head
self._adapt_span_layer = adapt_span_layer
if self._adapt_span_layer:
self._mask = AdaptiveMask(
max_size=self._max_span,
ramp_size=adapt_span_ramp,
init_val=adapt_span_init,
)
else:
self._mask = AdaptiveMask(
max_size=self._max_span,
ramp_size=adapt_span_ramp,
init_val=adapt_span_init,
shape=(n_head, 1, 1),
)
def forward(self, attn, normalize=True):
"""mask attention with the right span"""
# batch and head dimensions are merged together, so separate them first
self.clamp_param()
if self._adapt_span_layer:
attn = self._mask(attn)
else:
B = attn.size(0) # batch size
M = attn.size(1) # block size
attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
attn = self._mask(attn)
attn = attn.view(B, M, -1)
return attn
def get_trim_len(self):
"""how much of memory can be trimmed to reduce computation"""
L = self._max_span
trim_len = min(L - 1, L - self._mask.get_current_max_size())
# too fine granularity might be bad for the memory management
trim_len = math.floor(trim_len / 64) * 64
return trim_len
def trim_memory(self, query, key, value, key_pe):
"""trim out unnecessary memory beforehand to reduce computation"""
trim_len = self.get_trim_len()
cache_size = key.size(1) - query.size(1)
trim_len_cache = trim_len - (self._max_span - cache_size)
if trim_len_cache > 0:
key = key[:, trim_len_cache:, :]
value = value[:, trim_len_cache:, :]
elif trim_len_cache < 0:
# cache is too short! this happens when validation resumes
# after a lot of updates.
key = F.pad(key, [0, 0, -trim_len_cache, 0])
value = F.pad(value, [0, 0, -trim_len_cache, 0])
if trim_len > 0:
if key_pe is not None:
key_pe = key_pe[:, :, trim_len:]
return key, value, key_pe
def get_cache_size(self):
"""determine how long the cache should be"""
trim_len = self.get_trim_len()
# give a buffer of 64 steps since a span might increase
# in future updates
return min(self._max_span, self._max_span - trim_len + 64)
def get_loss(self):
"""a loss term for regularizing the span length"""
return self._max_span * self._mask.current_val.float().mean()
def get_current_max_span(self):
return self._mask.get_current_max_size()
def get_current_avg_span(self):
return self._mask.get_current_avg_size()
def clamp_param(self):
self._mask.clamp_param()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class AdaptiveSpanCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
class AdaptiveSpanCriterion(CrossEntropyCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task, sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, which is summed here (unlike in the original adaptive span code)
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, aux_loss, avg_span, max_span = self.compute_loss(
model, net_output, sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
loss /= sample_size
total_loss = loss + aux_loss
sample_size = 1
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
"total_loss": total_loss.data,
"avg_span": avg_span * sample_size,
"max_span": max_span * sample_size,
}
return total_loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
loss, _ = super().compute_loss(model, net_output, sample, reduce)
aux_loss = model.get_aux_loss()
avg_span = model.get_current_avg_span()
max_span = model.get_current_max_span()
return loss, aux_loss, avg_span, max_span
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
# total loss contains the L1 norm on adaptive-span
metrics.log_scalar(
"total_loss",
total_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules.layer_norm import LayerNorm
from .adaptive_span_attention import AdaptiveSpan
# Size notations:
# B = batch_size, H = d_model, M = block_size, L = attn_span
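# _skew/_unskew implement the usual relative-position bookkeeping: padding and
# reshaping shift row i of the attention matrix by i steps, converting between
# the "absolute position in the (cache + block) sequence" layout (B x M x (L+M))
# and the "relative to each query" layout (B x M x L) without an explicit gather.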
def _skew(X, pad_value):
"""shift every row 1 step to right"""
# X = B x M x L
B, M, L = X.size()
X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
X = X.view(B, -1) # B x ML+MM+M
X = X[:, :-M] # B x ML+MM
X = X.view(B, M, M + L) # B x M x L+M
return X
def _unskew(X):
"""reverse _skew operation"""
# X = B x M x L+M
B, M, L = X.size()
L -= M
X = X.view(B, -1) # B x ML+MM
X = F.pad(X, (0, M)) # B x ML+MM+M
X = X.view(B, M, M + L + 1) # B x M x L+M+1
X = X[:, :, :L] # B x M x L
return X
class SeqAttention(nn.Module):
"""Sequential self-attention layer.
Each token attends to a fixed number of previous steps.
Note that attention does not include the current step itself.
"""
def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
nn.Module.__init__(self)
self.dropout = nn.Dropout(dropout)
self.d_model = d_model # size of a single head
self.attn_span = attn_span
self.adaptive_span = AdaptiveSpan(
attn_span=attn_span,
n_head=n_head,
adapt_span_layer=adapt_span_layer,
**kargs
)
def forward(self, query, key, value, key_pe):
# query size = B x M x H
# key, value sizes = B x (M+L) x H
key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
# compute attention from context
# B x M (dest) x (M+L) (src)
attn_cont = torch.matmul(query, key.transpose(-1, -2))
attn_cont = _unskew(attn_cont) # B x M x L
# compute the effect of position embedding
attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
attn = attn_cont + attn_pos
attn = attn / math.sqrt(self.d_model) # B x M X L_pos
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
# trim attention lengths according to the learned span
attn = self.adaptive_span(attn)
attn = self.dropout(attn) # B x M X L_pos
attn_cont = _skew(attn, 0) # B x M X (L+M)
out = torch.matmul(attn_cont, value) # B x M x H
return out
def get_cache_size(self):
return self.adaptive_span.get_cache_size()
class MultiHeadSeqAttention(nn.Module):
def __init__(self, d_model, n_head, **kargs):
nn.Module.__init__(self)
assert d_model % n_head == 0
self.n_head = n_head
self.head_dim = d_model // n_head
self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
self.proj_query = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_query.weight)
self.proj_out = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_out.weight)
self.proj_val = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_val.weight)
self.proj_key = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_key.weight)
def head_reshape(self, x):
K = self.n_head
D = self.head_dim
x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
return x
def forward(self, query, key, value, key_pe):
B = query.size(0)
K = self.n_head
D = self.head_dim
M = query.size(1)
query = self.proj_query(query)
query = self.head_reshape(query)
value = self.proj_val(value)
value = self.head_reshape(value)
key = self.proj_key(key)
key = self.head_reshape(key)
out = self.attn(query, key, value, key_pe) # B_K x M x D
out = out.view(B, K, M, D) # B x K x M x D
out = out.transpose(1, 2).contiguous() # B x M x K x D
out = out.view(B, M, -1) # B x M x K_D
out = self.proj_out(out)
return out
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, d_inner, dropout, **kargs):
nn.Module.__init__(self)
self.fc1 = nn.Linear(d_model, d_inner)
self.fc2 = nn.Linear(d_inner, d_model)
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
self.dropout = nn.Dropout(dropout)
def forward(self, h):
h1 = F.relu(self.fc1(h))
h1 = self.dropout(h1)
h2 = self.fc2(h1)
return h2
class TransformerSeqLayer(nn.Module):
def __init__(self, d_model, **kargs):
nn.Module.__init__(self)
self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
self.norm1 = LayerNorm(d_model)
self.ff = FeedForwardLayer(d_model=d_model, **kargs)
self.norm2 = LayerNorm(d_model)
def forward(self, h, h_cache, key_pe):
# h = B x M x H
# h_cache = B x L x H
h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
attn_out = self.attn(h, h_all, h_all, key_pe)
h = self.norm1(h + attn_out) # B x M x H
if self.ff is not None:
ff_out = self.ff(h)
out = self.norm2(h + ff_out) # B x M x H
else:
out = h
return out
def get_cache_size(self):
return self.attn.attn.get_cache_size()
class TransformerSeq(nn.Module):
def __init__(
self,
vocab_size,
d_model,
n_head,
n_layer,
attn_span,
emb_dropout,
aux_loss_scaler,
adapt_span_layer,
**kargs
):
nn.Module.__init__(self)
# token embeddings
self.in_emb = nn.Embedding(vocab_size, d_model)
nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
self.out_emb = nn.Linear(d_model, vocab_size)
self.aux_loss_scaler = aux_loss_scaler
if emb_dropout > 0:
self.emb_dropout = nn.Dropout(emb_dropout)
else:
self.emb_dropout = None
# position embeddings
self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
self.layers = nn.ModuleList()
self.layers.extend(
TransformerSeqLayer(
d_model=d_model,
n_head=n_head,
attn_span=attn_span,
adapt_span_layer=adapt_span_layer,
**kargs
)
for _ in range(n_layer)
)
def forward(self, x, h_cache, target=None):
# x size = B x M
block_size = x.size(1)
h = self.in_emb(x) # B x M x H
if self.emb_dropout is not None:
h = self.emb_dropout(h)
h_cache_next = []
for l, layer in enumerate(self.layers):
cache_size = layer.attn.attn.get_cache_size()
if cache_size > block_size:
h_cache_next_l = torch.cat(
[h_cache[l][:, -cache_size + block_size :, :], h], dim=1
).detach()
else:
h_cache_next_l = h[:, -cache_size:, :].detach()
h_cache_next.append(h_cache_next_l)
h = layer(h, h_cache[l], self.key_pe) # B x M x H
if self.emb_dropout is not None:
h = self.emb_dropout(h)
out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
dummy_loss = None
return out, h_cache_next, dummy_loss
def get_aux_loss(self):
loss = 0.0
for layer in self.layers:
loss += layer.attn.attn.adaptive_span.get_loss()
return self.aux_loss_scaler * loss
def get_current_max_span(self):
max_span = 0.0
for layer in self.layers:
max_span = max(
max_span, layer.attn.attn.adaptive_span.get_current_max_span()
)
return max_span
def get_current_avg_span(self):
avg_span = 0.0
for layer in self.layers:
avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
return avg_span / len(self.layers)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
)
from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
logger = logging.getLogger(__name__)
@dataclass
class AdaptiveSpanSmallConfig(FairseqDataclass):
# defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
vocab_size: int = 50
d_model: int = 256
n_head: int = 4
d_inner: int = 1024
n_layer: int = 8
attn_span: int = 1024
dropout: float = 0.0
emb_dropout: float = 0.0
adapt_span_ramp: int = 32
adapt_span_init: float = 0.0
aux_loss_scaler: float = 0.000002
adapt_span_layer: bool = False
@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
class AdaptiveSpanTransformer(FairseqLanguageModel):
@classmethod
def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
return cls(AdaptiveSpanDecoder(cfg, task))
def get_aux_loss(self):
return self.decoder.get_aux_loss()
def get_current_max_span(self):
return self.decoder.get_current_max_span()
def get_current_avg_span(self):
return self.decoder.get_current_avg_span()
class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
def __init__(self, cfg, task):
super().__init__(task.target_dictionary)
self.config = cfg
config = AdaptiveSpanSmallConfig(
vocab_size=len(task.target_dictionary),
d_model=cfg.d_model,
n_head=cfg.n_head,
d_inner=cfg.d_inner,
n_layer=cfg.n_layer,
attn_span=cfg.attn_span,
dropout=cfg.dropout,
emb_dropout=cfg.emb_dropout,
adapt_span_ramp=cfg.adapt_span_ramp,
adapt_span_init=cfg.adapt_span_init,
aux_loss_scaler=cfg.aux_loss_scaler,
adapt_span_layer=cfg.adapt_span_layer,
)
logger.info(config)
self.model = AdaptiveSpanTransformerModel(**config.__dict__)
self._mems = None
def forward(
self,
src_tokens,
incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
encoder_out=None,
):
bsz = src_tokens.size(0)
if incremental_state is not None: # used during inference
mems = self.get_incremental_state(incremental_state, "mems")
src_tokens = src_tokens[:, -1:] # only keep the most recent token
else:
mems = self._mems
if mems is None:
# first time init
mems = self.init_hid_cache(bsz)
output = self.model(x=src_tokens, h_cache=mems,)
if incremental_state is not None:
self.set_incremental_state(incremental_state, "mems", output[1])
else:
self._mems = output[1]
return (output[0],)
def max_positions(self):
return self.config.attn_span
def init_hid_cache(self, batch_sz):
hid = []
for layer in self.model.layers:
param = next(self.model.parameters())
h = torch.zeros(
batch_sz,
layer.get_cache_size(),
self.config.d_model,
dtype=param.dtype,
device=param.device,
)
hid.append(h)
return hid
def get_aux_loss(self):
return self.model.get_aux_loss()
def get_current_max_span(self):
return self.model.get_current_max_span()
def get_current_avg_span(self):
return self.model.get_current_avg_span()
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
new_order: torch.Tensor,
):
"""Reorder incremental state.
This will be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
raise NotImplementedError("This is required for generation/beam search")
# mems = self.get_incremental_state(incremental_state, "mems")
# if mems is not None:
# new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
# self.set_incremental_state(incremental_state, "mems", new_mems)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import torch
from fairseq import utils
from fairseq.data import (
Dictionary,
TokenBlockDataset,
data_utils,
iterators,
)
from fairseq.dataclass import FairseqDataclass
from fairseq.distributed import utils as dist_utils
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II
logger = logging.getLogger(__name__)
@dataclass
class TruncatedBPTTLMConfig(FairseqDataclass):
data: str = field(default="???", metadata={"help": "path to data directory"})
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sequence"},
)
batch_size: int = II("dataset.batch_size")
# Some models use *max_target_positions* to know how many positional
# embeddings to learn. We use II(...) to make it default to
# *tokens_per_sample*, but in principle there could be more positional
# embeddings than tokens in a single batch. This may also be irrelevant for
# custom model implementations.
max_target_positions: int = II("task.tokens_per_sample")
# these will be populated automatically if not provided
data_parallel_rank: Optional[int] = None
data_parallel_size: Optional[int] = None
@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
class TruncatedBPTTLMTask(FairseqTask):
def __init__(self, cfg: TruncatedBPTTLMConfig):
super().__init__(cfg)
if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
if torch.distributed.is_initialized():
cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
else:
cfg.data_parallel_rank = 0
cfg.data_parallel_size = 1
# load the dictionary
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(self.dictionary)))
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)"""
# support sharded datasets
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
# each element of *data* will be a tensorized line from the original
# text dataset, similar to ``open(split_path).readlines()``
data = data_utils.load_indexed_dataset(
split_path, self.dictionary, combine=combine
)
if data is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
# this is similar to ``data.view(-1).split(tokens_per_sample)``
data = TokenBlockDataset(
data,
data.sizes,
block_size=self.cfg.tokens_per_sample,
pad=None, # unused
eos=None, # unused
break_mode="none",
)
self.datasets[split] = TruncatedBPTTDataset(
data=data,
bsz_per_shard=self.cfg.batch_size,
shard_id=self.cfg.data_parallel_rank,
num_shards=self.cfg.data_parallel_size,
)
def dataset(self, split):
return self.datasets[split]
def get_batch_iterator(
self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs
):
return iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=self._collate_fn,
num_workers=num_workers,
epoch=epoch,
buffer_size=data_buffer_size,
# we don't use the batching functionality from EpochBatchIterator;
# instead every item in *dataset* is a whole batch
batch_sampler=[[i] for i in range(len(dataset))],
disable_shuffling=True,
)
def _collate_fn(self, items: List[List[torch.Tensor]]):
# we don't use fairseq's batching functionality, so we expect a single
# Tensor of type List[torch.Tensor]
assert len(items) == 1
# item will have shape B x T (the last batch may have length < T)
id, item = items[0]
item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
B, T = item.size()
# shift item one position over and append a padding token for the target
target = torch.nn.functional.pad(
item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
)
# fairseq expects batches to have the following structure
return {
"id": torch.tensor([id]*item.size(0)),
"net_input": {
"src_tokens": item,
},
"target": target,
"nsentences": item.size(0),
"ntokens": item.numel(),
}
def build_dataset_for_inference(
self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
) -> torch.utils.data.Dataset:
eos = self.source_dictionary.eos()
dataset = TokenBlockDataset(
src_tokens,
src_lengths,
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=eos,
break_mode="eos",
)
class Dataset(torch.utils.data.Dataset):
def __getitem__(self, i):
item = dataset[i]
if item[-1] == eos:
# remove eos to support generating with a prefix
item = item[:-1]
return (i, [item])
def __len__(self):
return len(dataset)
return Dataset()
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
if constraints is not None:
raise NotImplementedError
# SequenceGenerator doesn't use *src_tokens* directly, we need to
# pass the *prefix_tokens* argument instead.
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
prefix_tokens = sample["net_input"]["src_tokens"]
# begin generation with the end-of-sentence token
bos_token = self.source_dictionary.eos()
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
)
def eval_lm_dataloader(
self,
dataset,
max_tokens: Optional[int] = 36000,
batch_size: Optional[int] = None,
max_positions: Optional[int] = None,
num_shards: int = 1,
shard_id: int = 0,
num_workers: int = 1,
data_buffer_size: int = 10,
context_window: int = 0,
):
if context_window > 0:
raise NotImplementedError(
"Transformer-XL doesn't need --context-window, try "
"--model-overrides '{\"mem_len\":42}' instead "
)
return self.get_batch_iterator(
dataset=dataset,
max_tokens=max_tokens,
max_sentences=batch_size,
max_positions=max_positions,
ignore_invalid_inputs=True,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
data_buffer_size=data_buffer_size,
).next_epoch_itr(shuffle=False)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
class TruncatedBPTTDataset(torch.utils.data.Dataset):
def __init__(
self,
data: List[torch.Tensor], # ordered list of items
bsz_per_shard, # number of items processed per GPU per forward pass
shard_id, # current GPU ID
num_shards, # number of GPUs
):
super().__init__()
self.data = data
def batchify(data, bsz):
# Work out how cleanly we can divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).contiguous()
return data
# total number of sequences processed by all GPUs in each forward pass
global_batch_size = bsz_per_shard * num_shards
"""
With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
*indices* might look like:
indices = [[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11]]
The size of the TruncatedBPTTDataset instance will be 2,
and shard 1 will see items:
[(0, [data[4], data[6]]),
(1, [data[5], data[7]])]
"""
indices = batchify(torch.arange(len(data)), global_batch_size)
assert indices.size(0) == global_batch_size
self.my_indices = indices[
shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
]
assert self.my_indices.size(0) == bsz_per_shard
def __len__(self):
return self.my_indices.size(1)
def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
return (i, [self.data[idx] for idx in self.my_indices[:, i]])
# Understanding Back-Translation at Scale (Edunov et al., 2018)
This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
## Pre-trained models
Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
## Example usage (torch.hub)
We require a few additional Python dependencies for preprocessing:
```bash
pip install subword_nmt sacremoses
```
Then to generate translations from the full model ensemble:
```python
import torch
# List available models
torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
# Load the WMT'18 En-De ensemble
en2de_ensemble = torch.hub.load(
'pytorch/fairseq', 'transformer.wmt18.en-de',
checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
tokenizer='moses', bpe='subword_nmt')
# The ensemble contains 5 models
len(en2de_ensemble.models)
# 5
# Translate
en2de_ensemble.translate('Hello world!')
# 'Hallo Welt!'
```
## Training your own model (WMT'18 English-German)
The following instructions can be adapted to reproduce the models from the paper.
#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..
# Binarize the data
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
--joined-dictionary \
--source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
--workers 20
# Copy the BPE code into the data-bin directory for future use
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
```
(Optionally) Train a baseline model (English-German) using just the parallel data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
data-bin/wmt18_en_de \
--source-lang en --target-lang de \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 30000 \
--save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```
Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
--inputs $CHECKPOINT_DIR \
--num-epoch-checkpoints 10 \
--output $CHECKPOINT_DIR/checkpoint.avg10.pt
```
Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
# compare to 29.46 in Table 1, which is also for tokenized BLEU
# generally it's better to report (detokenized) sacrebleu though:
bash examples/backtranslation/sacrebleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
```
#### Step 2. Back-translate monolingual German data
Train a reverse model (German-English) to do the back-translation:
```bash
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
data-bin/wmt18_en_de \
--source-lang de --target-lang en \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 30000 \
--save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```
Let's evaluate the back-translation (BT) model to make sure it is well trained:
```bash
bash examples/backtranslation/sacrebleu.sh \
wmt17 \
de-en \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint_best.pt
# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
```
Next prepare the monolingual data:
```bash
# Download and prepare the monolingual data
# By default the script samples 25M monolingual sentences, which after
# deduplication should be just over 24M sentences. These are split into 25
# shards, each with 1M sentences (except for the last shard).
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..
# Binarize each shard of the monolingual data
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-preprocess \
--only-source \
--source-lang de --target-lang en \
--joined-dictionary \
--srcdict data-bin/wmt18_en_de/dict.de.txt \
--testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
--destdir data-bin/wmt18_de_mono/shard${SHARD} \
--workers 20; \
cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
```
Now we're ready to perform back-translation over the monolingual data. The
following command generates via sampling, but it's possible to use greedy
decoding (`--beam 1`), beam search (`--beam 5`),
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
```bash
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-generate --fp16 \
data-bin/wmt18_de_mono/shard${SHARD} \
--path $CHECKPOINT_DIR/checkpoint_best.pt \
--skip-invalid-size-inputs-valid-test \
--max-tokens 4096 \
--sampling --beam 1 \
> backtranslation_output/sampling.shard${SHARD}.out; \
done
```
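As a concrete illustration of the alternative decoding strategies mentioned
above, a top-k sampling (k=10) variant of the same loop only changes the
decoding flags (the output file name here is just a suggestion):
```bash
mkdir -p backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-generate --fp16 \
data-bin/wmt18_de_mono/shard${SHARD} \
--path $CHECKPOINT_DIR/checkpoint_best.pt \
--skip-invalid-size-inputs-valid-test \
--max-tokens 4096 \
--sampling --sampling-topk 10 --beam 1 \
> backtranslation_output/sampling_topk10.shard${SHARD}.out; \
done
```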
After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
the back-translations and apply length ratio filters:
```bash
python examples/backtranslation/extract_bt_data.py \
--minlen 1 --maxlen 250 --ratio 1.5 \
--output backtranslation_output/bt_data --srclang en --tgtlang de \
backtranslation_output/sampling.shard*.out
# Ensure lengths are the same:
# wc -l backtranslation_output/bt_data.{en,de}
# 21795614 backtranslation_output/bt_data.en
# 21795614 backtranslation_output/bt_data.de
# 43591228 total
```
Binarize the filtered BT data and combine it with the parallel data:
```bash
TEXT=backtranslation_output
fairseq-preprocess \
--source-lang en --target-lang de \
--joined-dictionary \
--srcdict data-bin/wmt18_en_de/dict.en.txt \
--trainpref $TEXT/bt_data \
--destdir data-bin/wmt18_en_de_bt \
--workers 20
# We want to train on the combined data, so we'll symlink the parallel + BT data
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
# and the BT data as "train1", so that fairseq will combine them automatically
# and so that we can use the `--upsample-primary` option to upsample the
# parallel data (if desired).
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
mkdir -p $COMB_DATA
for LANG in en de; do \
ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
for EXT in bin idx; do \
ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
done; \
done
```
#### 3. Train an English-German model over the combined parallel + BT data
Finally we can train a model over the parallel + BT data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
data-bin/wmt18_en_de_para_plus_bt \
--upsample-primary 16 \
--source-lang en --target-lang de \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 100000 \
--save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```
Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
--inputs $CHECKPOINT_DIR \
--num-epoch-checkpoints 10 \
--output $CHECKPOINT_DIR/checkpoint.avg10.pt
```
Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
# compare to 32.35 in Table 1, which is also for tokenized BLEU
# generally it's better to report (detokenized) sacrebleu:
bash examples/backtranslation/sacrebleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
```
## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
title = {Understanding Back-Translation at Scale},
author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
booktitle = {Conference of the Association for Computational Linguistics (ACL)},
year = 2018,
}
```
#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import fileinput
import hashlib
import sys
from multiprocessing import Pool
def get_hashes_and_lines(raw_line):
hash = hashlib.md5(raw_line).hexdigest()
return hash, raw_line
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int, default=10)
parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
seen = set()
with fileinput.input(args.files, mode="rb") as h:
pool = Pool(args.workers)
results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
for i, (hash, raw_line) in enumerate(results):
if hash not in seen:
seen.add(hash)
sys.stdout.buffer.write(raw_line)
if i % 1000000 == 0:
print(i, file=sys.stderr, end="", flush=True)
elif i % 100000 == 0:
print(".", file=sys.stderr, end="", flush=True)
print(file=sys.stderr, flush=True)
if __name__ == "__main__":
main()
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import fileinput
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(
description=(
"Extract back-translations from the stdout of fairseq-generate. "
"If there are multiple hypotheses for a source, we only keep the first one. "
)
)
parser.add_argument("--output", required=True, help="output prefix")
parser.add_argument(
"--srclang", required=True, help="source language (extracted from H-* lines)"
)
parser.add_argument(
"--tgtlang", required=True, help="target language (extracted from S-* lines)"
)
parser.add_argument("--minlen", type=int, help="min length filter")
parser.add_argument("--maxlen", type=int, help="max length filter")
parser.add_argument("--ratio", type=float, help="ratio filter")
parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
def validate(src, tgt):
srclen = len(src.split(" ")) if src != "" else 0
tgtlen = len(tgt.split(" ")) if tgt != "" else 0
if (
(args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
or (
args.maxlen is not None
and (srclen > args.maxlen or tgtlen > args.maxlen)
)
or (
args.ratio is not None
and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
)
):
return False
return True
def safe_index(toks, index, default):
try:
return toks[index]
except IndexError:
return default
with open(args.output + "." + args.srclang, "w") as src_h, open(
args.output + "." + args.tgtlang, "w"
) as tgt_h:
for line in tqdm(fileinput.input(args.files)):
if line.startswith("S-"):
tgt = safe_index(line.rstrip().split("\t"), 1, "")
elif line.startswith("H-"):
if tgt is not None:
src = safe_index(line.rstrip().split("\t"), 2, "")
if validate(src, tgt):
print(src, file=src_h)
print(tgt, file=tgt_h)
tgt = None
if __name__ == "__main__":
main()
#!/bin/bash
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_CODE=wmt18_en_de/code
SUBSAMPLE_SIZE=25000000
LANG=de
OUTDIR=wmt18_${LANG}_mono
orig=orig
tmp=$OUTDIR/tmp
mkdir -p $OUTDIR $tmp
URLS=(
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
"http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
"http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
"http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
"http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
"http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
)
FILES=(
"news.2007.de.shuffled.gz"
"news.2008.de.shuffled.gz"
"news.2009.de.shuffled.gz"
"news.2010.de.shuffled.gz"
"news.2011.de.shuffled.gz"
"news.2012.de.shuffled.gz"
"news.2013.de.shuffled.gz"
"news.2014.de.shuffled.v2.gz"
"news.2015.de.shuffled.gz"
"news.2016.de.shuffled.gz"
"news.2017.de.shuffled.deduped.gz"
)
cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
file=${FILES[i]}
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
url=${URLS[i]}
wget "$url"
fi
done
cd ..
if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
echo "found monolingual sample, skipping shuffle/sample/tokenize"
else
gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
| shuf -n $SUBSAMPLE_SIZE \
| perl $NORM_PUNC $LANG \
| perl $REM_NON_PRINT_CHAR \
| perl $TOKENIZER -threads 8 -a -l $LANG \
> $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi
if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
echo "found BPE monolingual sample, skipping BPE step"
else
python $BPEROOT/apply_bpe.py -c $BPE_CODE \
< $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
> $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi
if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
echo "found deduplicated monolingual sample, skipping deduplication step"
else
python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
> $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
fi
if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
echo "found sharded data, skipping sharding step"
else
split --lines 1000000 --numeric-suffixes \
--additional-suffix .${LANG} \
$tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
$OUTDIR/bpe.monolingual.dedup.
fi
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=32000
URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
"http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
"training-parallel-europarl-v7.tgz"
"training-parallel-commoncrawl.tgz"
"training-parallel-nc-v13.tgz"
"rapid2016.tgz"
"dev.tgz"
"test-full.tgz"
)
CORPORA=(
"training/europarl-v7.de-en"
"commoncrawl.de-en"
"training-parallel-nc-v13/news-commentary-v13.de-en"
"rapid2016.de-en"
)
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit 1
fi
OUTDIR=wmt18_en_de
src=en
tgt=de
lang=en-de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
file=${FILES[i]}
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
url=${URLS[i]}
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
exit 1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
fi
fi
done
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
rm $tmp/train.tags.$lang.tok.$l
for f in "${CORPORA[@]}"; do
cat $orig/$f.$l | \
perl $NORM_PUNC $l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
done
done
echo "pre-processing test data..."
for l in $src $tgt; do
if [ "$l" == "$src" ]; then
t="src"
else
t="ref"
fi
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
echo ""
done
echo "splitting train and valid..."
for l in $src $tgt; do
awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done
TRAIN=$tmp/train.de-en
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
cat $tmp/train.$l >> $TRAIN
done
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
for L in $src $tgt; do
for f in train.$L valid.$L test.$L; do
echo "apply_bpe.py to ${f}..."
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
done
done
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
for L in $src $tgt; do
cp $tmp/bpe.test.$L $prep/test.$L
done
#!/bin/bash
if [ $# -ne 5 ]; then
echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
exit
fi
DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5
SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
BPEROOT=subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
fi
fi
sacrebleu -t $DATASET -l $LANGPAIR --echo src \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
-s $SRCLANG -t $TGTLANG \
--beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| sacremoses detokenize -l $TGTLANG -q \
| sacrebleu -t $DATASET -l $LANGPAIR
#!/bin/bash
if [ $# -ne 5 ]; then
echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
exit
fi
DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5
SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
BPEROOT=subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
fi
fi
TMP_REF=$(mktemp)
sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
| sacremoses normalize -l $TGTLANG -q \
| sacremoses tokenize -a -l $TGTLANG -q \
> $TMP_REF
sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
| sacremoses normalize -l $SRCLANG -q \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
-s $SRCLANG -t $TGTLANG \
--beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| fairseq-score --ref $TMP_REF
rm -f $TMP_REF
# Fine-tuning BART on GLUE tasks
### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```
### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.
### 3) Fine-tuning on GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61 # 6 percent of the number of updates
LR=1e-05 # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16 # Batch size.
BART_PATH=/path/to/bart/model.pt
CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
--restore-file $BART_PATH \
--batch-size $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--add-prev-output-tokens \
--layernorm-embedding \
--share-all-embeddings \
--share-decoder-input-output-embed \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 \
--arch bart_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--find-unused-parameters \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```
For each GLUE task, you will need to use the following command-line arguments:
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
**Note:**
a) `--total-num-update` is used by the `polynomial_decay` LR scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.
b) The above arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32GB` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size` (see the sketch below).
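For example, a minimal sketch of such an adjustment for `RTE` (hypothetical values; the effective batch size of 16 and the learning-rate schedule stay unchanged):
```bash
MAX_SENTENCES=8   # halve the per-step batch size (was 16)
UPDATE_FREQ=2     # accumulate gradients over 2 steps to compensate
# then pass "--batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ" to the
# fairseq-train command shown in step 3 above.
```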
### Inference on GLUE task
After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using the following Python code snippet:
```python
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='RTE-bin'
)
label_fn = lambda label: bart.task.label_dictionary.string(
[label + bart.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/RTE/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = bart.encode(sent1, sent2)
prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```