Commit 12c90639 (parent 417b607b): init
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional
from torch import Tensor
import torch
import torch.nn as nn
from fairseq.models import (
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
base_architecture,
Embedding,
TransformerModel,
TransformerEncoder,
TransformerDecoder,
)
from fairseq.modules import (
TransformerDecoderLayer,
)
logger = logging.getLogger(__name__)
@register_model("laser_transformer")
class LaserTransformerModel(FairseqEncoderDecoderModel):
"""Train Transformer for LASER task
Requires --task laser
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens=None,
tgt_tokens=None,
tgt_lengths=None,
target_language_id=-1,
dataset_name="",
):
laser_encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(
prev_output_tokens, laser_encoder_out, lang_id=target_language_id
)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
TransformerModel.add_args(parser)
parser.add_argument(
"--decoder-lang-embed-dim",
type=int,
metavar="N",
help="decoder language embedding dimension",
)
@classmethod
def build_model(cls, args, task):
base_laser_transformer_architecture(args)
num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
def load_embed_tokens(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
encoder_embed_tokens = load_embed_tokens(
task.source_dictionary, args.encoder_embed_dim
)
decoder_embed_tokens = load_embed_tokens(
task.target_dictionary, args.decoder_embed_dim
)
encoder = LaserTransformerEncoder(
args, task.source_dictionary, encoder_embed_tokens
)
decoder = LaserTransformerDecoder(
args,
task.target_dictionary,
decoder_embed_tokens,
num_langs=num_langs,
lang_embed_dim=args.decoder_lang_embed_dim,
)
return cls(encoder, decoder)
class LaserTransformerEncoder(TransformerEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, src_tokens, *args, **kwargs):
encoder_out = super().forward(src_tokens, *args, **kwargs)
x = encoder_out["encoder_out"][0] # T x B x C
padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
if padding_mask.any():
x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
# Build the sentence embedding by max-pooling over the encoder outputs
sentemb = x.max(dim=0)[0]
# The PyTorch Mobile lite interpreter does not support returning NamedTuple from
# `forward`, so we use a dictionary instead.
# TorchScript does not support mixed-type values, so all of the values are lists.
# The empty list is equivalent to None.
return {"sentemb": [sentemb]} # B x C
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
"""
Same as the one in transformer.py, with new_sentemb
"""
if len(encoder_out["sentemb"]) == 0:
new_sentemb = []
else:
new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)]
return {
"sentemb": new_sentemb, # B x C
}
class LaserTransformerDecoder(TransformerDecoder):
def __init__(self, args, dictionary, *kargs, **kwargs):
self.num_langs = kwargs.get("num_langs", 1)
self.lang_embed_dim = kwargs.get("lang_embed_dim", 0)
kwargs.pop("num_langs", None)
kwargs.pop("lang_embed_dim", None)
super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True)
if self.lang_embed_dim == 0:
self.embed_lang = None
else:
self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim)
nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)
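# The decoder features are concatenated with the language embedding and the
# encoder sentence embedding, so rebuild the output projection with the
# matching (wider) input dimension.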
if self.output_projection is not None:
laser_output_embed_dim = (
self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
)
self.output_projection = nn.Linear(
laser_output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight,
mean=0,
std=laser_output_embed_dim ** -0.5,
)
def build_decoder_layer(self, args, no_encoder_attn=False):
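# Temporarily widen args.decoder_embed_dim so the layer is built for the
# concatenated [token; lang embed; sentence embed] input, then restore it.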
decoder_embed_dim = args.decoder_embed_dim
args.decoder_embed_dim = (
decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
)
res = TransformerDecoderLayer(args, no_encoder_attn=True)
args.decoder_embed_dim = decoder_embed_dim
return res
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
lang_id: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = self.num_layers - 1
# embed positions
positions = (
self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
bsz, seqlen = prev_output_tokens.size()
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
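# Concatenate the target-language embedding and the encoder sentence embedding
# to every time step along the feature dimension.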
if self.embed_lang is not None:
lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
langemb = self.embed_lang(lang_ids)
langemb = langemb.unsqueeze(0)
repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * (
len(langemb.shape) - 1
)
x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1)
sentemb = encoder_out["sentemb"][0]
sentemb = sentemb.unsqueeze(0)
repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1)
x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
# decoder layers
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = layer(
x,
None,
None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
lang_id: Optional[int] = None,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
assert lang_id is not None
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
lang_id=lang_id,
)
if not features_only:
x = self.output_layer(x)
return x, extra
@register_model_architecture("laser_transformer", "laser_transformer")
def base_laser_transformer_architecture(args):
base_architecture(args)
args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import numpy as np
from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators
class MultiItr(object):
def __init__(self, itr):
self.itr = itr
self._counts = [0 for x in itr]
def __len__(self):
return sum(len(itr) for itr in self.itr)
def __iter__(self):
return self
def __next__(self):
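# Advance the iterator that is proportionally least consumed, so the datasets
# are interleaved according to their sizes.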
ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)]
idx = ratios.index(min(ratios))
self._counts[idx] += 1
return next(self.itr[idx])
class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
"""A wrapper around multiple epoch batch iterators."""
def __init__(
self,
dataset,
batch_sampler,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
):
assert isinstance(dataset, OrderedDict)
assert len(dataset)
assert isinstance(dataset[next(iter(dataset))], FairseqDataset)
self.iterators = []
self.epoch = epoch
for key, dt in dataset.items():
epoch_iter = iterators.EpochBatchIterator(
dataset=dt,
collate_fn=dt.collater,
batch_sampler=batch_sampler[key],
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=0,
epoch=epoch,
)
self.iterators.append(epoch_iter)
def __len__(self):
return sum(len(itr) for itr in self.iterators)
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
# `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
return MultiItr(
[
itr.next_epoch_itr(
shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
)
for itr in self.iterators
]
)
def end_of_epoch(self):
return all(itr.end_of_epoch() for itr in self.iterators)
@property
def next_epoch_idx(self):
"""Return the epoch index after *next_epoch_itr* is called."""
epochs = [itr.next_epoch_idx for itr in self.iterators]
self.epoch = epochs[0]
assert all(epoch == self.epoch for epoch in epochs)
return self.epoch
@property
def iterations_in_epoch(self):
return sum(itr.iterations_in_epoch for itr in self.iterators)
def state_dict(self):
return {
"iterators": [it.state_dict() for it in self.iterators],
"epoch": self.epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict["epoch"]
for it, d in zip(self.iterators, state_dict["iterators"]):
it.load_state_dict(d)
class MultitaskDatasetWrapper(BaseWrapperDataset):
"""A wrapper for a multitask dataset."""
def __init__(self, dataset, target_language_id, sample=1.0, name=""):
super().__init__(dataset)
self.target_language_id = target_language_id
self.sample = sample
self.name = name
def collater(self, *args, **kwargs):
ans = self.dataset.collater(*args, **kwargs)
if "net_input" in ans:
ans["net_input"]["target_language_id"] = self.target_language_id
ans["net_input"]["dataset_name"] = self.name
return ans
def num_tokens(self, *args, **kwargs):
return self.dataset.num_tokens(*args, **kwargs)
def ordered_indices(self, *args, **kwargs):
indices = self.dataset.ordered_indices(*args, **kwargs)
# Hacky solution for sampling
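# Keep a random fraction `sample` of the ordered indices, sorted so the
# ordering produced by the wrapped dataset is preserved.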
size = int(self.sample * indices.shape[0])
return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size]))
def size(self, index: int):
return self.dataset.size(index)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
# Deep Transformers with Latent Depth (Li et al., 2020)
[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).
## Introduction
We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
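For intuition, the sketch below shows the core layer-selection mechanism under simplifying assumptions (the `SimpleLayerSelect` class and its names are illustrative, not the module shipped in `latent_depth_src/modules/latent_layers.py`): each language owns a vector of per-layer logits, a soft gate is sampled per layer with a Gumbel-Sigmoid relaxation, and every Transformer block's output is scaled by its gate before the residual connection.
```python
import torch
import torch.nn as nn


class SimpleLayerSelect(nn.Module):
    """Per-(language, layer) logits; samples soft gates with a Gumbel-Sigmoid."""

    def __init__(self, num_layers: int, num_langs: int, tau: float = 5.0):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_langs, num_layers))
        self.tau = tau
        self.gates = None

    def sample(self, lang_idx: int):
        logits = self.logits[lang_idx]
        # Logistic noise turns the sigmoid into a soft Bernoulli sample.
        u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
        noise = torch.log(u) - torch.log1p(-u)
        self.gates = torch.sigmoid((logits + noise) / self.tau)

    def forward(self, layer_idx: int) -> torch.Tensor:
        return self.gates[layer_idx]


# Each layer then scales its block output by the sampled gate, so layers whose
# gates collapse toward zero are effectively skipped for that language:
#   x = residual + gate(layer_idx) * block(x)
```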
## Training a multilingual model with latent depth
Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
fairseq-train ${databin_dir} \
--user-dir examples/latent_depth/latent_depth_src \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation_latent_depth \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=legacy_ddp \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
## Inference command
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>
fairseq-generate ${databin_dir} \
--path ${model_path} \
--task multilingual_translation_latent_depth \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```
## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import multilingual_translation_latent_depth # noqa
from .loss import latent_depth # noqa
from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.nn.modules.loss import _Loss
class LatentLayersKLLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
def forward(self, layer_samples, lang_idx, update_num, sample_size):
prior = self.args.prior
samples = layer_samples[lang_idx]
eps = 1e-7
if prior == "uniform":
# uniform prior
kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
elif prior == "agged_posterior":
# aggregated posterior
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
agged_q = torch.sum(y_t, dim=0)
row_norm = agged_q.sum(-1)
normed_agg_q = agged_q / row_norm
kl_loss = (
samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
).sum(-1)
else:
raise NotImplementedError("The specified prior is not implemented.")
# normalized by number of layers
kl_loss /= layer_samples[0].size()[0]
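# Linearly anneal the KL weight up to sparsity_weight over anneal_updates
# steps, starting after soft_update steps.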
kl_weight = min(
self.args.sparsity_weight,
(update_num - self.args.soft_update)
* self.args.sparsity_weight
/ self.args.anneal_updates,
)
kl_loss *= kl_weight * sample_size
return kl_loss
class LatentLayersSparsityLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
def is_valid(self, update_num):
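# The sparsity loss only counts once a target number of layers is set and the
# annealing phase (soft_update + anneal_updates steps) has passed.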
if self.args.target_layers <= 0:
return False
return update_num > (self.args.soft_update + self.args.anneal_updates)
def forward(self, layer_samples_list, update_num, sample_size):
batch_loss = 0
share_loss = 0
global_sparsity_loss = 0
layer_samples = torch.stack(layer_samples_list, dim=0)
if (
self.args.target_layers > 0 or self.args.share_weight > 0
) and update_num > (self.args.soft_update + self.args.anneal_updates):
# anneal sparsity weight
if update_num < (self.args.anneal_updates + self.args.soft_update):
weight_anneal = 0
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
weight_anneal = (
(update_num - self.args.soft_update - self.args.anneal_updates)
* self.args.share_weight
/ self.args.anneal_updates
)
else:
weight_anneal = 1
# compute ratio among languages
layer_utilization = torch.sum(layer_samples, dim=0)
layer_utilization /= layer_samples.size()[0]
if self.args.share_weight > 0:
# encouraging sharing across languages
share_loss = sum(
-1.0 * v * math.log(v) for v in layer_utilization if v > 0
)
batch_loss += (
weight_anneal * self.args.share_weight * sample_size * share_loss
)
if self.args.target_layers > 0:
# compute the expected number of selected layers
expected_layers = sum(layer_utilization)
# compute l2 loss wrt target number of layers
global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
batch_loss += (
weight_anneal
* self.args.share_weight
* sample_size
* global_sparsity_loss
)
return batch_loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
base_architecture,
)
from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
"""A variant of standard multilingual Transformer models which encoder and/or
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
(https://arxiv.org/abs/2009.13102).
"""
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
MultilingualTransformerModel.add_args(parser)
parser.add_argument(
'--soft-select',
action='store_true',
help='use soft samples in training and inference',
)
parser.add_argument(
'--sampling-tau',
type=float,
default=5.,
help='sampling temperature',
)
@classmethod
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
if is_encoder:
if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
return LatentTransformerEncoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerEncoder(args, lang_dict, embed_tokens)
else:
if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
return LatentTransformerDecoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerDecoder(args, lang_dict, embed_tokens)
@register_model_architecture(
"latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.decoder_layers = getattr(args, "decoder_layers", 24)
args.share_encoders = getattr(args, "share_encoders", True)
args.share_decoders = getattr(args, "share_decoders", True)
args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor
from ..modules.latent_layers import LayerSelect
class LatentTransformerEncoder(TransformerEncoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerEncoder.
"""
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
self.num_logits = num_logits
self.num_layers = args.encoder_layers
super().__init__(args, dictionary, embed_tokens)
self.layer_select = LayerSelect(
num_layers=self.num_layers,
num_logits=self.num_logits,
soft_select=getattr(args, "soft_select", False),
sampling_tau=getattr(args, "sampling_tau", 5.),
)
self.lang_idx = None
self.layers = nn.ModuleList(
[self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_encoder_layer(self, args, idx=None):
return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)
def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
self.layer_select.sample(self.lang_idx)
return super().forward(src_tokens, src_lengths, return_all_hiddens)
class LatentTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer with each (non_residual) block weighted by samples of Bernouli
or Gumbel Signmoid samples.
Args:
args (argparse.Namespace): parsed command-line arguments from standard
TransformerEncoderLayer.
idx (int): layer index (used to retrieve samples).
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
"""
def __init__(self, args, idx, layer_select=None):
super().__init__(args)
self.idx = idx
self.layer_select = layer_select
def residual_connection(self, x, residual):
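# Gate this block's output with the sampled layer-selection weight before
# adding the residual.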
return residual + x * self.layer_select(self.idx)
class LatentTransformerDecoder(TransformerDecoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerDecoder.
"""
def __init__(
self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
):
self.num_logits = num_logits
self.num_layers = args.decoder_layers
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
)
self.layer_select = LayerSelect(
num_layers=self.num_layers,
num_logits=self.num_logits,
soft_select=getattr(args, "soft_select", False),
sampling_tau=getattr(args, "sampling_tau", 5.),
)
self.lang_idx = None
self.layers = nn.ModuleList(
[
self._build_decoder_layer(args, no_encoder_attn, idx)
for idx in range(args.decoder_layers)
]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
return LatentTransformerDecoderLayer(
args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
)
def forward(
self,
prev_output_tokens,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
self.layer_select.sample(self.lang_idx)
return super().forward(
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
features_only=features_only,
alignment_layer=alignment_layer,
src_lengths=src_lengths,
return_all_hiddens=return_all_hiddens,
)
class LatentTransformerDecoderLayer(TransformerDecoderLayer):
"""Decoder layer with each (non_residual) block weighted by samples of Bernouli
or Gumbel Signmoid samples.
Args:
args (argparse.Namespace): parsed command-line arguments from standard
TransformerDecoderLayer.
idx (int): layer index (used to retrieve samples).
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
args,
idx,
layer_select=None,
no_encoder_attn=False,
add_bias_kv=False,
add_zero_attn=False,
):
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
self.idx = idx
self.layer_select = layer_select
def residual_connection(self, x, residual):
return residual + x * self.layer_select(self.idx)