# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .laser_task import * # noqa
from .laser_lstm import * # noqa
from .laser_transformer import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
@register_model("laser_lstm")
class LSTMModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens=None,
tgt_tokens=None,
tgt_lengths=None,
target_language_id=None,
dataset_name="",
):
assert target_language_id is not None
src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name)
return self.decoder(
prev_output_tokens, src_encoder_out, lang_id=target_language_id
)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--dropout",
default=0.1,
type=float,
metavar="D",
help="dropout probability",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-embed-path",
default=None,
type=str,
metavar="STR",
help="path to pre-trained encoder embedding",
)
parser.add_argument(
"--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size"
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="number of encoder layers"
)
parser.add_argument(
"--encoder-bidirectional",
action="store_true",
help="make all layers of encoder bidirectional",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-embed-path",
default=None,
type=str,
metavar="STR",
help="path to pre-trained decoder embedding",
)
parser.add_argument(
"--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size"
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="number of decoder layers"
)
parser.add_argument(
"--decoder-out-embed-dim",
type=int,
metavar="N",
help="decoder output embedding dimension",
)
parser.add_argument(
"--decoder-zero-init",
type=str,
metavar="BOOL",
help="initialize the decoder hidden/cell state to zero",
)
parser.add_argument(
"--decoder-lang-embed-dim",
type=int,
metavar="N",
help="decoder language embedding dimension",
)
parser.add_argument(
"--fixed-embeddings",
action="store_true",
help="keep embeddings fixed (ENCODER ONLY)",
) # TODO Also apply to decoder embeddings?
# Granular dropout settings (if not specified these default to --dropout)
parser.add_argument(
"--encoder-dropout-in",
type=float,
metavar="D",
help="dropout probability for encoder input embedding",
)
parser.add_argument(
"--encoder-dropout-out",
type=float,
metavar="D",
help="dropout probability for encoder output",
)
parser.add_argument(
"--decoder-dropout-in",
type=float,
metavar="D",
help="dropout probability for decoder input embedding",
)
parser.add_argument(
"--decoder-dropout-out",
type=float,
metavar="D",
help="dropout probability for decoder output",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted (in case there are any new ones)
base_architecture(args)
def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
embed_dict = utils.parse_embedding(embed_path)
utils.print_embed_overlap(embed_dict, dictionary)
return utils.load_embedding(embed_dict, dictionary, embed_tokens)
pretrained_encoder_embed = None
if args.encoder_embed_path:
pretrained_encoder_embed = load_pretrained_embedding_from_file(
args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim
)
pretrained_decoder_embed = None
if args.decoder_embed_path:
pretrained_decoder_embed = load_pretrained_embedding_from_file(
args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim
)
num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
encoder = LSTMEncoder(
dictionary=task.source_dictionary,
embed_dim=args.encoder_embed_dim,
hidden_size=args.encoder_hidden_size,
num_layers=args.encoder_layers,
dropout_in=args.encoder_dropout_in,
dropout_out=args.encoder_dropout_out,
bidirectional=args.encoder_bidirectional,
pretrained_embed=pretrained_encoder_embed,
fixed_embeddings=args.fixed_embeddings,
)
decoder = LSTMDecoder(
dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
hidden_size=args.decoder_hidden_size,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=args.decoder_layers,
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
zero_init=options.eval_bool(args.decoder_zero_init),
encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
num_langs=num_langs,
lang_embed_dim=args.decoder_lang_embed_dim,
)
return cls(encoder, decoder)
class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(
self,
dictionary,
embed_dim=512,
hidden_size=512,
num_layers=1,
dropout_in=0.1,
dropout_out=0.1,
bidirectional=False,
left_pad=True,
pretrained_embed=None,
padding_value=0.0,
fixed_embeddings=False,
):
super().__init__(dictionary)
self.num_layers = num_layers
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.bidirectional = bidirectional
self.hidden_size = hidden_size
num_embeddings = len(dictionary)
self.padding_idx = dictionary.pad()
if pretrained_embed is None:
self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
else:
self.embed_tokens = pretrained_embed
if fixed_embeddings:
self.embed_tokens.weight.requires_grad = False
self.lstm = LSTM(
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=self.dropout_out if num_layers > 1 else 0.0,
bidirectional=bidirectional,
)
self.left_pad = left_pad
self.padding_value = padding_value
self.output_units = hidden_size
if bidirectional:
self.output_units *= 2
def forward(self, src_tokens, src_lengths, dataset_name):
if self.left_pad:
# convert left-padding to right-padding
src_tokens = utils.convert_padding_direction(
src_tokens,
self.padding_idx,
left_to_right=True,
)
bsz, seqlen = src_tokens.size()
# embed tokens
x = self.embed_tokens(src_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# pack embedded source tokens into a PackedSequence
try:
packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
except BaseException as e:
raise Exception(f"Packing failed in dataset {dataset_name}") from e
# apply LSTM
if self.bidirectional:
state_size = 2 * self.num_layers, bsz, self.hidden_size
else:
state_size = self.num_layers, bsz, self.hidden_size
h0 = x.data.new(*state_size).zero_()
c0 = x.data.new(*state_size).zero_()
packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
# unpack outputs and apply dropout
x, _ = nn.utils.rnn.pad_packed_sequence(
packed_outs, padding_value=self.padding_value
)
x = F.dropout(x, p=self.dropout_out, training=self.training)
assert list(x.size()) == [seqlen, bsz, self.output_units]
if self.bidirectional:
def combine_bidir(outs):
return torch.cat(
[
torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(
1, bsz, self.output_units
)
for i in range(self.num_layers)
],
dim=0,
)
final_hiddens = combine_bidir(final_hiddens)
final_cells = combine_bidir(final_cells)
encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
# Set padded outputs to -inf so they are not selected by max-pooling
padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
if padding_mask.any():
x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
# Build the sentence embedding by max-pooling over the encoder outputs
sentemb = x.max(dim=0)[0]
return {
"sentemb": sentemb,
"encoder_out": (x, final_hiddens, final_cells),
"encoder_padding_mask": encoder_padding_mask
if encoder_padding_mask.any()
else None,
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select(
0, new_order
)
encoder_out_dict["encoder_out"] = tuple(
eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"]
)
if encoder_out_dict["encoder_padding_mask"] is not None:
encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out_dict
def max_positions(self):
"""Maximum input length supported by the encoder."""
return int(1e5) # an arbitrary large number
class LSTMDecoder(FairseqIncrementalDecoder):
"""LSTM decoder."""
def __init__(
self,
dictionary,
embed_dim=512,
hidden_size=512,
out_embed_dim=512,
num_layers=1,
dropout_in=0.1,
dropout_out=0.1,
zero_init=False,
encoder_embed_dim=512,
encoder_output_units=512,
pretrained_embed=None,
num_langs=1,
lang_embed_dim=0,
):
super().__init__(dictionary)
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.hidden_size = hidden_size
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
if pretrained_embed is None:
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
else:
self.embed_tokens = pretrained_embed
self.layers = nn.ModuleList(
[
LSTMCell(
input_size=encoder_output_units + embed_dim + lang_embed_dim
if layer == 0
else hidden_size,
hidden_size=hidden_size,
)
for layer in range(num_layers)
]
)
if hidden_size != out_embed_dim:
self.additional_fc = Linear(hidden_size, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
if zero_init:
self.sentemb2init = None
else:
self.sentemb2init = Linear(
encoder_output_units, 2 * num_layers * hidden_size
)
if lang_embed_dim == 0:
self.embed_lang = None
else:
self.embed_lang = nn.Embedding(num_langs, lang_embed_dim)
nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)
def forward(
self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0
):
sentemb = encoder_out_dict["sentemb"]
encoder_out = encoder_out_dict["encoder_out"]
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bsz, seqlen = prev_output_tokens.size()
# get outputs from encoder
encoder_outs, _, _ = encoder_out[:3]
srclen = encoder_outs.size(0)
# embed tokens
x = self.embed_tokens(prev_output_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
# embed language identifier
if self.embed_lang is not None:
lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
langemb = self.embed_lang(lang_ids)
# TODO Should we dropout here???
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# initialize previous states (or get from cache during incremental generation)
cached_state = utils.get_incremental_state(
self, incremental_state, "cached_state"
)
if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state
else:
num_layers = len(self.layers)
if self.sentemb2init is None:
prev_hiddens = [
x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
]
prev_cells = [
x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
]
else:
init = self.sentemb2init(sentemb)
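# init has shape (bsz, 2 * num_layers * hidden_size) and packs the states
# as [h_0; c_0; h_1; c_1; ...], so the even slices below are hidden states
# and the odd slices are cell states.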
prev_hiddens = [
init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size]
for i in range(num_layers)
]
prev_cells = [
init[
:,
(2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size,
]
for i in range(num_layers)
]
input_feed = x.data.new(bsz, self.hidden_size).zero_()
attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
outs = []
for j in range(seqlen):
if self.embed_lang is None:
input = torch.cat((x[j, :, :], sentemb), dim=1)
else:
input = torch.cat((x[j, :, :], sentemb, langemb), dim=1)
for i, rnn in enumerate(self.layers):
# recurrent cell
hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
# hidden state becomes the input to the next layer
input = F.dropout(hidden, p=self.dropout_out, training=self.training)
# save state for next time step
prev_hiddens[i] = hidden
prev_cells[i] = cell
out = hidden
out = F.dropout(out, p=self.dropout_out, training=self.training)
# input feeding
input_feed = out
# save final output
outs.append(out)
# cache previous states (no-op except during incremental generation)
utils.set_incremental_state(
self,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, input_feed),
)
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
# srclen x tgtlen x bsz -> bsz x tgtlen x srclen
attn_scores = attn_scores.transpose(0, 2)
# project back to size of vocabulary
if hasattr(self, "additional_fc"):
x = self.additional_fc(x)
x = F.dropout(x, p=self.dropout_out, training=self.training)
x = self.fc_out(x)
return x, attn_scores
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
cached_state = utils.get_incremental_state(
self, incremental_state, "cached_state"
)
if cached_state is None:
return
def reorder_state(state):
if isinstance(state, list):
return [reorder_state(state_i) for state_i in state]
return state.index_select(0, new_order)
new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(self, incremental_state, "cached_state", new_state)
def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.uniform_(m.weight, -0.1, 0.1)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def LSTM(input_size, hidden_size, **kwargs):
m = nn.LSTM(input_size, hidden_size, **kwargs)
for name, param in m.named_parameters():
if "weight" in name or "bias" in name:
param.data.uniform_(-0.1, 0.1)
return m
def LSTMCell(input_size, hidden_size, **kwargs):
m = nn.LSTMCell(input_size, hidden_size, **kwargs)
for name, param in m.named_parameters():
if "weight" in name or "bias" in name:
param.data.uniform_(-0.1, 0.1)
return m
def Linear(in_features, out_features, bias=True, dropout=0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.uniform_(-0.1, 0.1)
if bias:
m.bias.data.uniform_(-0.1, 0.1)
return m
@register_model_architecture("laser_lstm", "laser_lstm")
def base_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_hidden_size = getattr(
args, "encoder_hidden_size", args.encoder_embed_dim
)
args.encoder_layers = getattr(args, "encoder_layers", 1)
args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False)
args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout)
args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_hidden_size = getattr(
args, "decoder_hidden_size", args.decoder_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 1)
args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout)
args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
args.decoder_zero_init = getattr(args, "decoder_zero_init", "0")
args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
args.fixed_embeddings = getattr(args, "fixed_embeddings", False)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict, defaultdict
import json
import os
import logging
from fairseq import options, models
from fairseq.data import (
data_utils,
Dictionary,
LanguagePairDataset,
IndexedDataset,
FairseqDataset,
)
from .multitask_data_utils import (
MultitaskDatasetWrapper,
MultidatasetEpochBatchIterator,
)
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("laser")
class LaserTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"configfile", metavar="PATH", help="dataset configuration file in json"
)
parser.add_argument(
"--weighting-alpha",
type=float,
default=None,
help="alpha for automatic weighting",
)
parser.add_argument(
"--raw-text", action="store_true", help="load raw text dataset"
)
parser.add_argument(
"--left-pad-source",
default="True",
type=str,
metavar="BOOL",
help="pad the source on the left (default: True)",
)
parser.add_argument(
"--left-pad-target",
default="False",
type=str,
metavar="BOOL",
help="pad the target on the left (default: False)",
)
parser.add_argument(
"--max-source-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
super().__init__(args)
self.config = config
self.src_dictionary = src_dictionary
self.tgt_dictionary = tgt_dictionary
self.num_tasks = num_tasks
@classmethod
def setup_task(cls, args, **kwargs):
with open(args.configfile, "r") as f:
config = json.load(f)
num_tasks = max(dataset["id"] for dataset in config["train"]) + 1
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
src_dictionary = Dictionary.load(config["src_vocab"])
tgt_dictionary = Dictionary.load(config["tgt_vocab"])
logger.info(
"| src Dictionary {} : {} types".format(
config["src_vocab"], len(src_dictionary)
)
)
logger.info(
"| tgt Dictionary {} : {} types".format(
config["tgt_vocab"], len(tgt_dictionary)
)
)
return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)
# Experimental overriding for backtranslation
def build_model(self, args):
model = models.build_model(args, self)
return model
def dataset(self, split):
if split not in self.datasets:
raise KeyError("Dataset not loaded: " + split)
return self.datasets[split]
def load_dataset(self, split, epoch=1, **kwargs):
"""Load a dataset split."""
def indexed_dataset(path, dictionary):
if self.args.raw_text:
raise Exception("Unable to handle raw text.")
dataset = IndexedDataset(path, fix_lua_indexing=True)
return dataset
pair_datasets = OrderedDict()
if split == "valid":
self.datasets[split] = pair_datasets
return
if split not in self.config:
raise FileNotFoundError(
"Dataset not found in config file: {}".format(split)
)
size_by_corpus = defaultdict(int)
size_sum = 0
size_sum_with_subsampling = 0
init_pair_datasets = {}
for dataset_config in self.config[split]:
src_path = os.path.dirname(dataset_config["src"])
corpus_name = src_path.split("/")[-2]
language_pair_name = src_path.split("/")[-1]
pair_datasets_key = corpus_name + "-" + language_pair_name
logger.info(f"loading... {pair_datasets_key}")
if "src" in dataset_config:
src_dataset = indexed_dataset(
dataset_config["src"], self.src_dictionary
)
else:
src_dataset = None
if "tgt" in dataset_config:
tgt_dataset = indexed_dataset(
dataset_config["tgt"], self.tgt_dictionary
)
else:
tgt_dataset = None
dataset = LanguagePairDataset(
src_dataset,
src_dataset.sizes,
self.src_dictionary,
tgt_dataset,
tgt_dataset.sizes,
self.tgt_dictionary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
if pair_datasets_key in init_pair_datasets:
logger.warning(
f"Ignoring already added {pair_datasets_key}. "
f"Consider using `sample` key in order to upsample."
)
else:
init_pair_datasets[pair_datasets_key] = {
"dataset": dataset,
"sample": dataset_config.get("sample", None),
"id": dataset_config.get("id", None),
"len": len(dataset),
}
length_sum = 0
weighted_freqs_sum = 0
freq_per_dataset = {}
vmax = 0
vmin = 1
weighted_freq_per_dataset = {}
if self.args.weighting_alpha:
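# Temperature-based resampling: dataset frequencies are raised to the
# power alpha and renormalized, which up-weights low-resource pairs.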
for key in init_pair_datasets:
if init_pair_datasets[key]["sample"] is None:
length_sum += len(init_pair_datasets[key]["dataset"])
for key in init_pair_datasets:
if init_pair_datasets[key]["sample"] is None:
val = float(init_pair_datasets[key]["len"]) / length_sum
freq_per_dataset[key] = val
weighted_freqs_sum += val ** self.args.weighting_alpha
for key in freq_per_dataset:
val = (
freq_per_dataset[key] ** self.args.weighting_alpha
/ weighted_freqs_sum
)
vmin = min(vmin, val)
vmax = max(vmax, val)
weighted_freq_per_dataset[key] = val
for pair_datasets_key in init_pair_datasets:
dataset_config = init_pair_datasets[pair_datasets_key]
dataset = dataset_config["dataset"]
sample = dataset_config["sample"]
if sample is None:
sample = 1.0
if pair_datasets_key in weighted_freq_per_dataset:
w = vmax / weighted_freq_per_dataset[pair_datasets_key]
sample = w
sample = round(sample)
initial_sample = sample
initial_pair_datasets_key = pair_datasets_key
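# Upsample by registering the dataset multiple times; each extra copy is
# keyed with a "-up" suffix so the dataset keys stay unique.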
while sample >= 1.0:
assert (
pair_datasets_key not in pair_datasets
), f"{pair_datasets_key} already in"
size_sum_with_subsampling += len(dataset)
pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
)
size_sum += len(dataset)
sample -= 1.0
pair_datasets_key += "-up"
assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"
logger.info(
f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
)
size_by_corpus[corpus_name] += len(dataset)
self.datasets[split] = pair_datasets
logger.info(
f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
)
@property
def source_dictionary(self):
return self.src_dictionary
@property
def target_dictionary(self):
return self.tgt_dictionary
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
):
assert isinstance(dataset, OrderedDict)
assert len(dataset)
assert isinstance(dataset[next(iter(dataset))], FairseqDataset)
# initialize the dataset with the correct starting epoch
for _, dt in dataset.items():
dt.set_epoch(epoch)
indices = OrderedDict()
batch_sampler = OrderedDict()
with data_utils.numpy_seed(seed + epoch):
for key, dt in dataset.items():
logger.info(f"\t ordered_indices {key}")
indices[key] = dt.ordered_indices()
# filter examples that are too large
if max_positions is not None:
for key, dt in dataset.items():
logger.info(f"\t filter_by_size {key}")
indices[key], ignored = dt.filter_indices_by_size(
indices[key], max_positions
)
for key, dt in dataset.items():
logger.info(f"\t batch_by_size {key}")
batch_sampler[key] = data_utils.batch_by_size(
indices[key],
dt.num_tokens,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
epoch_iter = MultidatasetEpochBatchIterator(
dataset=dataset,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
)
return epoch_iter
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional
from torch import Tensor
import torch
import torch.nn as nn
from fairseq.models import (
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
base_architecture,
Embedding,
TransformerModel,
TransformerEncoder,
TransformerDecoder,
)
from fairseq.modules import (
TransformerDecoderLayer,
)
logger = logging.getLogger(__name__)
@register_model("laser_transformer")
class LaserTransformerModel(FairseqEncoderDecoderModel):
"""Train Transformer for LASER task
Requires --task laser
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens=None,
tgt_tokens=None,
tgt_lengths=None,
target_language_id=-1,
dataset_name="",
):
laser_encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(
prev_output_tokens, laser_encoder_out, lang_id=target_language_id
)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
TransformerModel.add_args(parser)
parser.add_argument(
"--decoder-lang-embed-dim",
type=int,
metavar="N",
help="decoder language embedding dimension",
)
@classmethod
def build_model(cls, args, task):
base_laser_transformer_architecture(args)
num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
def load_embed_tokens(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
encoder_embed_tokens = load_embed_tokens(
task.source_dictionary, args.encoder_embed_dim
)
decoder_embed_tokens = load_embed_tokens(
task.target_dictionary, args.decoder_embed_dim
)
num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
encoder = LaserTransformerEncoder(
args, task.source_dictionary, encoder_embed_tokens
)
decoder = LaserTransformerDecoder(
args,
task.target_dictionary,
decoder_embed_tokens,
num_langs=num_langs,
lang_embed_dim=args.decoder_lang_embed_dim,
)
return cls(encoder, decoder)
class LaserTransformerEncoder(TransformerEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, src_tokens, *args, **kwargs):
encoder_out = super().forward(src_tokens, *args, **kwargs)
x = encoder_out["encoder_out"][0] # T x B x C
padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
if padding_mask.any():
x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
# Build the sentence embedding by max-pooling over the encoder outputs
sentemb = x.max(dim=0)[0]
# The PyTorch Mobile lite interpreter does not support returning a NamedTuple
# from `forward`, so we use a dictionary instead.
# TorchScript does not support mixed-type values, so all values are lists.
# An empty list is equivalent to None.
return {"sentemb": [sentemb]} # B x C
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
"""
Same as the one in transformer.py, with new_sentemb
"""
if len(encoder_out["sentemb"]) == 0:
new_sentemb = []
else:
new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)]
return {
"sentemb": new_sentemb, # B x C
}
class LaserTransformerDecoder(TransformerDecoder):
def __init__(self, args, dictionary, *kargs, **kwargs):
self.num_langs = kwargs.get("num_langs", 1)
self.lang_embed_dim = kwargs.get("lang_embed_dim", 0)
kwargs.pop("num_langs", None)
kwargs.pop("lang_embed_dim", None)
super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True)
if self.lang_embed_dim == 0:
self.embed_lang = None
else:
self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim)
nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)
if self.output_projection is not None:
laser_output_embed_dim = (
self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
)
self.output_projection = nn.Linear(
laser_output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight,
mean=0,
std=laser_output_embed_dim ** -0.5,
)
def build_decoder_layer(self, args, no_encoder_attn=False):
decoder_embed_dim = args.decoder_embed_dim
args.decoder_embed_dim = (
decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
)
res = TransformerDecoderLayer(args, no_encoder_attn=True)
args.decoder_embed_dim = decoder_embed_dim
return res
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
lang_id: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = self.num_layers - 1
# embed positions
positions = (
self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
bsz, seqlen = prev_output_tokens.size()
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.embed_lang is not None:
lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
langemb = self.embed_lang(lang_ids)
langemb = langemb.unsqueeze(0)
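# Broadcast the language embedding across all decoder time steps, then
# concatenate it to the token features along the channel dimension.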
repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * (
len(langemb.shape) - 1
)
x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1)
sentemb = encoder_out["sentemb"][0]
sentemb = sentemb.unsqueeze(0)
repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1)
x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
# decoder layers
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = layer(
x,
None,
None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
lang_id: Optional[int] = None,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
assert lang_id is not None
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
lang_id=lang_id,
)
if not features_only:
x = self.output_layer(x)
return x, extra
@register_model_architecture("laser_transformer", "laser_transformer")
def base_laser_transformer_architecture(args):
base_architecture(args)
args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import numpy as np
from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators
class MultiItr(object):
def __init__(self, itr):
self.itr = itr
self._counts = [0 for x in itr]
def __len__(self):
return sum(len(itr) for itr in self.itr)
def __iter__(self):
return self
def __next__(self):
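# Pick the iterator that is proportionally furthest behind, so that all
# datasets are consumed at roughly the same relative rate.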
ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)]
idx = ratios.index(min(ratios))
self._counts[idx] += 1
return next(self.itr[idx])
class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
"""A wrapper around multiple epoch batch iterators."""
def __init__(
self,
dataset,
batch_sampler,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
):
assert isinstance(dataset, OrderedDict)
assert len(dataset)
assert isinstance(dataset[next(iter(dataset))], FairseqDataset)
self.iterators = []
self.epoch = epoch
for key, dt in dataset.items():
epoch_iter = iterators.EpochBatchIterator(
dataset=dt,
collate_fn=dt.collater,
batch_sampler=batch_sampler[key],
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=0,
epoch=epoch,
)
self.iterators.append(epoch_iter)
def __len__(self):
return sum(len(itr) for itr in self.iterators)
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
# `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
return MultiItr(
[
itr.next_epoch_itr(
shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
)
for itr in self.iterators
]
)
def end_of_epoch(self):
return all(itr.end_of_epoch() for itr in self.iterators)
@property
def next_epoch_idx(self):
"""Return the epoch index after *next_epoch_itr* is called."""
epochs = [itr.next_epoch_idx for itr in self.iterators]
self.epoch = epochs[0]
assert all(epoch == self.epoch for epoch in epochs)
return self.epoch
@property
def iterations_in_epoch(self):
return sum(itr.iterations_in_epoch for itr in self.iterators)
def state_dict(self):
return {
"iterators": [it.state_dict() for it in self.iterators],
"epoch": self.epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict["epoch"]
for it, d in zip(self.iterators, state_dict["iterators"]):
it.load_state_dict(d)
class MultitaskDatasetWrapper(BaseWrapperDataset):
"""A wrapper for a multitask dataset."""
def __init__(self, dataset, target_language_id, sample=1.0, name=""):
super().__init__(dataset)
self.target_language_id = target_language_id
self.sample = sample
self.name = name
def collater(self, *args, **kwargs):
ans = self.dataset.collater(*args, **kwargs)
if "net_input" in ans:
ans["net_input"]["target_language_id"] = self.target_language_id
ans["net_input"]["dataset_name"] = self.name
return ans
def num_tokens(self, *args, **kwargs):
return self.dataset.num_tokens(*args, **kwargs)
def ordered_indices(self, *args, **kwargs):
indices = self.dataset.ordered_indices(*args, **kwargs)
# Hacky solution for sampling: draw a random subset of the indices while
# preserving their original order
size = int(self.sample * indices.shape[0])
return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size]))
def size(self, index: int):
return self.dataset.size(index)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
# Deep Transformers with Latent Depth (Li et al., 2020)
[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).
## Introduction
We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
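The core mechanic is a per-language-pair gate on each layer's residual block, sampled from a Gumbel-Sigmoid relaxation of a Bernoulli. Below is a minimal sketch of the idea (the names are illustrative, not the fairseq API; the full version is the `LayerSelect` module in this commit):

```python
import torch

def gumbel_sigmoid(logits, tau=5.0, hard=True):
    # The difference of two Gumbel(0, 1) samples is Logistic noise, so
    # sigmoid((logits + noise) / tau) is a relaxed Bernoulli sample.
    g1 = -torch.empty_like(logits).exponential_().log()
    g2 = -torch.empty_like(logits).exponential_().log()
    y_soft = torch.sigmoid((logits + g1 - g2) / tau)
    if not hard:
        return y_soft
    # Straight-through estimator: binary gate forward, soft gradient backward.
    y_hard = (y_soft > 0.5).float()
    return y_hard - y_soft.detach() + y_soft

num_languages, num_layers = 8, 24
# One row of layer-selection logits per language pair.
layer_logits = torch.nn.Parameter(torch.zeros(num_languages, num_layers))
z = gumbel_sigmoid(layer_logits[0])  # gates for language 0, shape (num_layers,)
# Each layer l then computes: x = residual + z[l] * block(x)
```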
## Training a multilingual model with latent depth
Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
fairseq-train ${databin_dir} \
--user-dir examples/latent_depth/latent_depth_src \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation_latent_depth \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=legacy_ddp \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
## Inference command
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>
fairseq-generate ${databin_dir} \
--path ${model_path} \
--task multilingual_translation_latent_depth \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```
## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import multilingual_translation_latent_depth # noqa
from .loss import latent_depth # noqa
from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.nn.modules.loss import _Loss
class LatentLayersKLLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
def forward(self, layer_samples, lang_idx, update_num, sample_size):
prior = self.args.prior
samples = layer_samples[lang_idx]
eps = 1e-7
if prior == "uniform":
# uniform prior
kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
elif prior == "agged_posterior":
# aggregated posterior
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
agged_q = torch.sum(y_t, dim=0)
row_norm = agged_q.sum(-1)
normed_agg_q = agged_q / row_norm
kl_loss = (
samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
).sum(-1)
else:
raise NotImplementedError("The specified prior is not implemented.")
# normalize by the number of layers
kl_loss /= layer_samples[0].size()[0]
kl_weight = min(
self.args.sparsity_weight,
(update_num - self.args.soft_update)
* self.args.sparsity_weight
/ self.args.anneal_updates,
)
kl_loss *= kl_weight * sample_size
return kl_loss
class LatentLayersSparsityLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
def is_valid(self, update_num):
if self.args.target_layers <= 0:
return False
return update_num > (self.args.soft_update + self.args.anneal_updates)
def forward(self, layer_samples_list, update_num, sample_size):
batch_loss = 0
share_loss = 0
global_sparsity_loss = 0
layer_samples = torch.stack(layer_samples_list, dim=0)
if (
self.args.target_layers > 0 or self.args.share_weight > 0
) and update_num > (self.args.soft_update + self.args.anneal_updates):
# anneal sparsity weight
if update_num < (self.args.anneal_updates + self.args.soft_update):
weight_anneal = 0
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
weight_anneal = (
(update_num - self.args.soft_update - self.args.anneal_updates)
* self.args.share_weight
/ self.args.anneal_updates
)
else:
weight_anneal = 1
# compute ratio among languages
layer_utilization = torch.sum(layer_samples, dim=0)
layer_utilization /= layer_samples.size()[0]
if self.args.share_weight > 0:
# encouraging sharing across languages
share_loss = sum(
-1.0 * v * math.log(v) for v in layer_utilization if v > 0
)
batch_loss += (
weight_anneal * self.args.share_weight * sample_size * share_loss
)
if self.args.target_layers > 0:
# compute the expected number of selected layers
expected_layers = sum(layer_utilization)
# compute the L2 loss w.r.t. the target number of layers
global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
batch_loss += (
weight_anneal
* self.args.share_weight
* sample_size
* global_sparsity_loss
)
return batch_loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
base_architecture,
)
from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
"""A variant of standard multilingual Transformer models which encoder and/or
decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
(https://arxiv.org/abs/2009.13102).
"""
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
MultilingualTransformerModel.add_args(parser)
parser.add_argument(
'--soft-select',
action='store_true',
help='use soft samples in training and inference',
)
parser.add_argument(
'--sampling-tau',
type=float,
default=5.,
help='sampling temperature',
)
@classmethod
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
if is_encoder:
if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
return LatentTransformerEncoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerEncoder(args, lang_dict, embed_tokens)
else:
if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
return LatentTransformerDecoder(
args, lang_dict, embed_tokens, num_logits=len(langs)
)
else:
return TransformerDecoder(args, lang_dict, embed_tokens)
@register_model_architecture(
"latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.decoder_layers = getattr(args, "decoder_layers", 24)
args.share_encoders = getattr(args, "share_encoders", True)
args.share_decoders = getattr(args, "share_decoders", True)
args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor
from ..modules.latent_layers import LayerSelect
class LatentTransformerEncoder(TransformerEncoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerEncoder.
"""
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
self.num_logits = num_logits
self.num_layers = args.encoder_layers
super().__init__(args, dictionary, embed_tokens)
self.layer_select = LayerSelect(
num_layers=self.num_layers,
num_logits=self.num_logits,
soft_select=getattr(args, "soft_select", False),
sampling_tau=getattr(args, "sampling_tau", 5.),
)
self.lang_idx = None
self.layers = nn.ModuleList(
[self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_encoder_layer(self, args, idx=None):
return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)
def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
self.layer_select.sample(self.lang_idx)
return super().forward(src_tokens, src_lengths, return_all_hiddens)
class LatentTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer with each (non_residual) block weighted by samples of Bernouli
or Gumbel Signmoid samples.
Args:
args (argparse.Namespace): parsed command-line arguments from standard
TransformerEncoderLayer.
idx (int): layer index (used to retrieve samples).
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
"""
def __init__(self, args, idx, layer_select=None):
super().__init__(args)
self.idx = idx
self.layer_select = layer_select
def residual_connection(self, x, residual):
return residual + x * self.layer_select(self.idx)
class LatentTransformerDecoder(TransformerDecoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerDecoder.
"""
def __init__(
self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
):
self.num_logits = num_logits
self.num_layers = args.decoder_layers
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
)
self.layer_select = LayerSelect(
num_layers=self.num_layers,
num_logits=self.num_logits,
soft_select=getattr(args, "soft_select", False),
sampling_tau=getattr(args, "sampling_tau", 5.),
)
self.lang_idx = None
self.layers = nn.ModuleList(
[
self._build_decoder_layer(args, no_encoder_attn, idx)
for idx in range(args.decoder_layers)
]
)
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
return LatentTransformerDecoderLayer(
args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
)
def forward(
self,
prev_output_tokens,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
self.layer_select.sample(self.lang_idx)
return super().forward(
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
features_only=features_only,
alignment_layer=alignment_layer,
src_lengths=src_lengths,
return_all_hiddens=return_all_hiddens,
)
class LatentTransformerDecoderLayer(TransformerDecoderLayer):
"""Decoder layer with each (non_residual) block weighted by samples of Bernouli
or Gumbel Signmoid samples.
Args:
args (argparse.Namespace): parsed command-line arguments from standard
TransformerDecoderLayer.
idx (int): layer index (used to retrieve samples).
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
args,
idx,
layer_select=None,
no_encoder_attn=False,
add_bias_kv=False,
add_zero_attn=False,
):
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
self.idx = idx
self.layer_select = layer_select
def residual_connection(self, x, residual):
return residual + x * self.layer_select(self.idx)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
class LayerSelect(nn.Module):
"""Compute samples (from a Gumbel-Sigmoid distribution) which is used as
either (soft) weighting or (hard) selection of residual connection.
https://arxiv.org/abs/2009.13102
"""
def __init__(self, num_layers, num_logits, soft_select=False, sampling_tau=5.):
super(LayerSelect, self).__init__()
self.layer_logits = torch.nn.Parameter(
torch.Tensor(num_logits, num_layers),
requires_grad=True,
)
self.hard_select = not soft_select
self.tau = sampling_tau
self.detach_grad = False
self.layer_samples = [None] * num_logits
def sample(self, logit_idx):
"""To leverage the efficiency of distributed training, samples for all
layers are computed at once for each logit_idx. Logits are parameters
learned independently of each other.
Args:
logit_idx: The index of logit parameters used for sampling.
"""
assert logit_idx is not None
self.samples = self._gumbel_sigmoid(
self.layer_logits[logit_idx, :].detach()
if self.detach_grad
else self.layer_logits[logit_idx, :],
dim=-1,
tau=self.tau,
hard=self.hard_select,
)
self.layer_samples[logit_idx] = self.samples
def forward(self, i):
sample = self.samples[i]
return sample
def _gumbel_sigmoid(
self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
):
# ~Gumbel(0,1)
gumbels1 = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_()
.log()
)
gumbels2 = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_()
.log()
)
# Use the difference of two Gumbel samples (Logistic noise) because we
# apply a sigmoid rather than a softmax
gumbels1 = (logits + gumbels1 - gumbels2) / tau
y_soft = gumbels1.sigmoid()
if hard:
# Straight through.
y_hard = torch.zeros_like(
logits, memory_format=torch.legacy_contiguous_format
).masked_fill(y_soft > threshold, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
# Reparametrization trick.
ret = y_soft
return ret
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss
@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
"""A task for multiple translation with latent depth.
See `"Deep Transformer with Latent Depth"
(Li et al., 2020) <https://arxiv.org/pdf/2009.13102.pdf>`_.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
MultilingualTranslationTask.add_args(parser)
parser.add_argument('--encoder-latent-layer', action='store_true', help='latent layer selection in encoder')
parser.add_argument('--decoder-latent-layer', action='store_true', help='latent layer selection in decoder')
parser.add_argument('--target-layers', default=-1, type=int,
help='number of effective layers to learn; -1 means no constraint')
parser.add_argument('--sparsity-weight', default=0.0, type=float,
help='weight for sparsity loss')
parser.add_argument('--share-weight', default=0.0, type=float,
help='weight for sharing loss')
parser.add_argument('--soft-update', default=1, type=int,
help='number of updates with soft sampling')
parser.add_argument('--anneal-updates', default=1, type=int,
help='number of updates to anneal the KL loss weight')
parser.add_argument('--prior', default="uniform", type=str,
help='prior used for computing KL loss')
# fmt: on
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
self.src_langs, self.tgt_langs = zip(
*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
)
if self.training and self.encoder_latent_layer:
assert self.args.share_encoders
if self.training and self.decoder_latent_layer:
assert self.args.share_decoders
if training or self.encoder_latent_layer or self.decoder_latent_layer:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
self.eval_lang_pairs = self.lang_pairs
self.model_lang_pairs = self.lang_pairs
if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
self.kl_loss = LatentLayersKLLoss(self.args)
self.sparsity_loss = LatentLayersSparsityLoss(self.args)
def _per_lang_pair_train_loss(
self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
):
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
model.models[lang_pair].encoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
model.models[lang_pair].decoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
if self.encoder_latent_layer:
none_samples = sum(
1 if x is None else 0
for x in model.models[lang_pair].encoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].encoder.layer_select.layer_samples,
src_lang_idx,
update_num,
sample_size,
)
if self.decoder_latent_layer:
none_samples = sum(
1 if x is None else 0
for x in model.models[lang_pair].decoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].decoder.layer_select.layer_samples,
tgt_lang_idx,
update_num,
sample_size,
)
if ignore_grad:
loss *= 0
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
# need to retain the graph if sparsity loss needs to be added
loss.backward(retain_graph=True)
else:
optimizer.backward(loss)
return loss, sample_size, logging_output
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
agg_loss, agg_sample_size, agg_logging_output = super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad
)
# compute the auxiliary layer-sparsity loss, based on all samples from all languages
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
sparsity_loss = 0
if self.encoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(
iter(model.models.values())
).encoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if self.decoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(
iter(model.models.values())
).decoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if sparsity_loss > 0:
optimizer.backward(sparsity_loss)
return agg_loss, agg_sample_size, agg_logging_output
def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
return loss, sample_size, logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
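        # at inference time, pin each model's latent layer selection to the
        # source/target languages being evaluated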
if self.encoder_latent_layer or self.decoder_latent_layer:
for model in models:
if self.encoder_latent_layer:
assert model.encoder.layer_select is not None
src_lang_idx = self.src_lang_idx_dict[self.args.source_lang]
model.encoder.set_lang_idx(src_lang_idx)
if self.decoder_latent_layer:
assert model.decoder.layer_select is not None
tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
model.decoder.set_lang_idx(tgt_lang_idx)
return super().inference_step(
generator, models, sample, prefix_tokens, constraints
)
@property
def encoder_latent_layer(self):
return (
hasattr(self.args, "encoder_latent_layer")
and self.args.encoder_latent_layer
)
@property
def decoder_latent_layer(self):
return (
hasattr(self.args, "decoder_latent_layer")
and self.args.decoder_latent_layer
)
@property
def src_lang_idx_dict(self):
return {lang: lang_idx for lang_idx, lang in enumerate(self.src_langs)}
@property
def tgt_lang_idx_dict(self):
return {lang: lang_idx for lang_idx, lang in enumerate(self.tgt_langs)}
# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)
This page contains information on how to train models with LayerDrop, based on this [paper](https://arxiv.org/abs/1909.11556).
## Citation
If you found this technique useful, please cite our paper:
```bibtex
@article{fan2019reducing,
title={Reducing Transformer Depth on Demand with Structured Dropout},
author={Fan, Angela and Grave, Edouard and Joulin, Armand},
journal={arXiv preprint arXiv:1909.11556},
year={2019}
}
```
## Pre-trained models
Model | Description | Download
---|---|---
`layerdrop_wmt_en_de_12_6` | Transformer + LayerDrop 0.2 trained on WMT16 en-de with 12 encoder and 6 decoder layers | [layerdrop_wmt_en_de_12_6.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/layerdrop_wmt_en_de_12_6.tar.gz)
`roberta_layerdrop.base` | RoBERTa Base + LayerDrop 0.2 | [roberta_layerdrop.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.base.qnli.tar.gz)
`roberta_layerdrop.large` | RoBERTa Large + LayerDrop 0.2 | [roberta_layerdrop.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.tar.gz)
`roberta_layerdrop.large.mnli` | `roberta_layerdrop.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | [roberta_layerdrop.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.mnli.tar.gz)
`roberta_layerdrop.large.qnli` | `roberta_layerdrop.large` finetuned on [QNLI](https://arxiv.org/abs/1804.07461) | [roberta_layerdrop.large.qnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.qnli.tar.gz)
Evaluate performance of these pre-trained models:
```bash
# Example for Machine Translation
fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \
--beam 8 --lenpen 0.4 \
--batch-size 64 \
--remove-bpe \
--gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
# prints BLEU4 = 30.17
```
```python
# Example for RoBERTa + LayerDrop finetuned on MNLI:
from fairseq.models.roberta import RobertaModel
roberta_layerdrop = RobertaModel.from_pretrained(
'/path/to/MNLI/model',
checkpoint_file='mnli_checkpoint.pt',
data_name_or_path='/path/to/MNLI/data/MNLI-bin'
)
label_map = {0: 'contradiction', 2: 'neutral', 1: 'entailment'}
ncorrect, nsamples = 0, 0
roberta_layerdrop.cuda()
roberta_layerdrop.eval()
with open('/path/to/MNLI/data/dev_matched.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
tokens = roberta_layerdrop.encode(sent1, sent2)
prediction = roberta_layerdrop.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_map[prediction]
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# prints | Accuracy: 0.9026999490575649
# Example for RoBERTa + LayerDrop finetuned on QNLI:
roberta = RobertaModel.from_pretrained(
'/path/to/QNLI/model',
checkpoint_file='qnli_checkpoint.pt',
data_name_or_path='/path/to/QNLI/data/QNLI-bin'
)
label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('/path/to/QNLI/data/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# prints | Accuracy: 0.9480139117700896
```
## Example usage
To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For decoder-only Language Models, you need only the decoder flag; for encoder-only models such as RoBERTa, you need only the encoder flag. The encoder and decoder LayerDrop values can be set differently.
```
--encoder-layerdrop 0.2 --decoder-layerdrop 0.2
```
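Conceptually, LayerDrop skips each layer independently with the given probability during training and runs every layer at inference. The following minimal sketch illustrates the idea (it is not fairseq's implementation; the class and names here are illustrative only):
```python
import torch
import torch.nn as nn

class LayerDropStack(nn.Module):
    """Illustrative only: run a stack of layers, skipping each one
    independently with probability ``p`` while training."""

    def __init__(self, layers, p=0.2):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.p = p

    def forward(self, x):
        for layer in self.layers:
            # drop the layer only during training; always run it at inference
            if self.training and torch.rand(()).item() < self.p:
                continue
            x = layer(x)
        return x
```
Because every layer is trained to be skippable, whole layers can be removed at inference time with little loss in quality, which is what the pruning flags below exploit.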
To prune a model that has been trained with LayerDrop, add the following flags, each followed by a comma-separated list of the layers you would like to keep.
```
--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14
```
Setting these flags should print a message such as:
```
| Pruning model to specified layer configuration
```
You should also see a smaller number of parameters in the model. For example, the 16-layer Transformer Language Model prints:
```
num. model params: 246933504
```
while the same model pruned to 8 layers prints:
```
num. model params: 146163712
```
To pick up training with a model that has been pruned, simply adding these flags is sufficient. To run an evaluation-only script (no training), you may need to pass an override command. For example, for language modeling:
```bash
fairseq-eval-lm /path/to/wikitext-103 \
--path /path/to/model/checkpoint.pt \
--model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```
This override command replaces the stored training parameters and updates the model arguments so that the pruned model is run instead of the full model.
## Reproduce Paper Results
Looking to reproduce the results in the paper?
1. For Translation on WMT16 en-de, we followed the setup described [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md)
2. To train RoBERTa, we followed the setup described [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta)
3. To train Language Models on Wikitext-103, we followed the setup described [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model)
## Tips
1. To train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop over-regularizes the model, which can keep it from reaching its best performance. Since LayerDrop adds regularization, you may achieve the best results by slightly reducing the amount of standard dropout (for example, by 0.1).
2. To train large models that will be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2. For reference, our experiments were conducted with low values of LayerDrop (0.1 and 0.2).
3. When pruning layers at inference time, it is best to spread out the remaining layers so they are evenly spaced throughout the network. For example, to remove 50% of the network, keeping every other layer works well (see the sketch below).
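As a concrete illustration of even spacing, this hypothetical helper (not part of fairseq) computes the indices to pass to the `--encoder-layers-to-keep`/`--decoder-layers-to-keep` flags:
```python
def layers_to_keep(num_layers, num_keep):
    """Pick ``num_keep`` evenly spaced layer indices out of ``num_layers``."""
    stride = num_layers / num_keep
    return sorted({int(i * stride) for i in range(num_keep)})

# keep 8 of 16 layers -> 0,2,4,6,8,10,12,14, matching the pruning example above
print(",".join(str(i) for i in layers_to_keep(16, 8)))
```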
## FAQ
1. How did the layer-sharing experiment work? In an appendix (https://openreview.net/pdf?id=SylO2yStDr) we added an experiment on Wikitext-103 language modeling that combined LayerDrop with weight sharing. We shared chunks of 2 layers such that every other layer had shared weights: for example, in a network with layers 1 through 6, layers 1 and 2 are shared, layers 3 and 4 are shared, and layers 5 and 6 are shared (see the sketch after this list).
2. What if LayerDrop isn't helping in my setting? During training, LayerDrop can help regularize your network. This matters most if your network is already overfitting; if your network is underfitting, it is possible LayerDrop is adding too much regularization. We recommend using smaller values (such as 0.1 or 0.2) and also decreasing the quantity of standard dropout (for example, by 0.1).
3. Can you train a model without LayerDrop and finetune with LayerDrop (e.g. for BERT)? In our experiments, we did not see great performance. Models such as RoBERTa have trained for a long time in the pre-training setting, so finetuning with LayerDrop for only a few epochs on a downstream task such as MNLI does not achieve the robustness required for successful pruning.
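For the layer-sharing setup in FAQ 1, here is a minimal sketch (assumed, not the paper's actual code): each unique layer object appears twice in the stack, so consecutive layers share parameters.
```python
import torch.nn as nn

# three unique layers, each used twice -> a 6-layer stack in which
# layers 1-2, 3-4, and 5-6 share weights
unique_layers = [nn.TransformerEncoderLayer(d_model=512, nhead=8) for _ in range(3)]
stack = nn.ModuleList([layer for layer in unique_layers for _ in range(2)])
assert stack[0] is stack[1]  # the same module object, hence shared parameters
```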
## Having an issue or have a question?
Please open an issue in this repository with the details of your question. Thanks!
# Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)
This example contains code to train Linformer models as described in our paper
[Linformer: Self-Attention with Linear Complexity](https://arxiv.org/abs/2006.04768).
## Training a new Linformer RoBERTa model
You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md),
updating your training command with `--user-dir examples/linformer/linformer_src --arch linformer_roberta_base`.
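The `--compressed` flag sets the factor k by which keys and values are shortened before attention. The following is a minimal sketch of the idea only, not the actual module in `linformer_src`:
```python
import torch
import torch.nn as nn

n, d, k = 512, 64, 4                     # sequence length, head dim, compression ratio
proj = nn.Linear(n, n // k, bias=False)  # learned projection over the length axis

q = torch.randn(1, n, d)
kv = torch.randn(1, n, d)

# project keys/values from length n down to n/k, then attend as usual;
# attention now costs O(n * n/k) instead of O(n^2)
kv_c = proj(kv.transpose(1, 2)).transpose(1, 2)                # (1, n/k, d)
attn = torch.softmax(q @ kv_c.transpose(1, 2) / d ** 0.5, -1)  # (1, n, n/k)
out = attn @ kv_c                                              # (1, n, d)
```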
## Citation
If you use our work, please cite:
```bibtex
@article{wang2020linformer,
title={Linformer: Self-Attention with Linear Complexity},
author={Wang, Sinong and Li, Belinda and Khabsa, Madian and Fang, Han and Ma, Hao},
journal={arXiv preprint arXiv:2006.04768},
year={2020}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .models import linformer_roberta # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Linformer: Self-Attention with Linear Complexity
"""
import logging
import torch
from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
init_bert_params,
roberta_base_architecture,
roberta_large_architecture,
RobertaEncoder,
RobertaModel,
)
from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder
logger = logging.getLogger(__name__)
@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
@staticmethod
def add_args(parser):
RobertaModel.add_args(parser)
# add args for Linformer
parser.add_argument(
"--compressed", type=int, help="compressed ratio of sequence length"
)
parser.add_argument(
"--shared-kv-compressed",
type=int,
help="share compressed matrix between k and v, in each layer",
)
parser.add_argument(
"--shared-layer-kv-compressed",
type=int,
help="share compressed matrix between k and v and across all layers",
)
parser.add_argument(
"--freeze-compress",
type=int,
help="freeze the parameters in compressed layer",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present
base_architecture(args)
if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
encoder = LinformerEncoder(args, task.source_dictionary)
return cls(args, encoder)
class LinformerEncoder(RobertaEncoder):
"""Linformer encoder."""
def __init__(self, args, dictionary):
super().__init__(args, dictionary)
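        # bump the checkpoint version so upgrade_state_dict_named can detect
        # older checkpoints whose weight sharing was implemented incorrectly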
self.register_buffer("version", torch.tensor(2))
def build_encoder(self, args, dictionary, embed_tokens):
encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens)
encoder.apply(init_bert_params)
return encoder
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
prefix = name + "." if name != "" else ""
# some old checkpoints had weight sharing implemented incorrectly
# (note: this was correct in the original paper code)
if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2:
state_dict[f"{prefix}version"] = torch.tensor(1)
# check if input embeddings and output embeddings were tied
if not torch.allclose(
state_dict[f"{prefix}sentence_encoder.embed_tokens.weight"],
state_dict[f"{prefix}lm_head.weight"],
):
# they weren't tied, re-init the LM head without weight sharing
self.lm_head = self.build_lm_head(
embed_dim=self.args.encoder_embed_dim,
output_dim=len(self.dictionary),
activation_fn=self.args.activation_fn,
weight=None, # don't share weights
)
@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
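    # fall back to the paper's defaults for any Linformer flag not set on the command line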
args.compressed = getattr(args, "compressed", 4)
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
args.freeze_compress = getattr(args, "freeze_compress", 0)
roberta_base_architecture(args)
@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
roberta_large_architecture(args)
base_architecture(args)