Commit 7143f128 authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'hepj-test' into 'main'

更新transformer代码

See merge request dcutoolkit/deeplearing/dlexamples_new!47
parents a30b77fe c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from contextlib import ExitStack
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore
from omegaconf import open_dict, OmegaConf
from .composite_encoder import CompositeEncoder
from .distributed_fairseq_model import DistributedFairseqModel
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqLanguageModel,
FairseqModel,
FairseqMultiModel,
)
# Registries filled in by ``register_model`` and ``register_model_architecture``.
# model name -> model class
MODEL_REGISTRY = {}
# model name -> config dataclass (only for models registered with a dataclass)
MODEL_DATACLASS_REGISTRY = {}
# architecture name -> model class
ARCH_MODEL_REGISTRY = {}
# architecture name -> model name
ARCH_MODEL_NAME_REGISTRY = {}
# model name -> list of architecture names registered for that model
ARCH_MODEL_INV_REGISTRY = {}
# architecture name -> function that fills architecture defaults into a config
ARCH_CONFIG_REGISTRY = {}
# Names exported by ``from <this package> import *``.
__all__ = [
"BaseFairseqModel",
"CompositeEncoder",
"DistributedFairseqModel",
"FairseqDecoder",
"FairseqEncoder",
"FairseqEncoderDecoderModel",
"FairseqEncoderModel",
"FairseqIncrementalDecoder",
"FairseqLanguageModel",
"FairseqModel",
"FairseqMultiModel",
]
def build_model(cfg: FairseqDataclass, task, from_checkpoint=False):
    """Resolve the model class named by *cfg* and build an instance of it.

    The model type is read from ``cfg._name`` and, failing that, ``cfg.arch``.
    If neither is set and *cfg* has exactly one entry, that entry's key is
    taken as the model type and its value becomes the effective config.

    Args:
        cfg: model configuration (a config object or ``argparse.Namespace``).
        task: passed through unchanged to the model's ``build_model``.
        from_checkpoint: forwarded to ``merge_with_parent`` when the config is
            merged with the registered dataclass defaults.

    Returns:
        Whatever ``<model class>.build_model(cfg, task)`` returns.

    Raises:
        Exception: the single nested key is not a registered dataclass model.
        AssertionError: no registered model matches the resolved model type.
    """
    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # this is hit if config object is nested in directory that is named after model type
        model_type = next(iter(cfg))
        if model_type in MODEL_DATACLASS_REGISTRY:
            cfg = cfg[model_type]
        else:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                "Available models: "
                + str(MODEL_DATACLASS_REGISTRY.keys())
                + " Requested model type: "
                + model_type
            )

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]
        if isinstance(cfg, argparse.Namespace):
            cfg = dc.from_namespace(cfg)
        else:
            cfg = merge_with_parent(dc(), cfg, from_checkpoint)
    elif model_type in ARCH_CONFIG_REGISTRY:
        with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
            # this calls the different "arch" functions (like base_architecture()) that you indicate
            # if you specify --arch on the command line. this is only applicable to the old argparse based models
            # hydra models should expose different architectures via different config files
            # it will modify the cfg object and default parameters according to the arch
            ARCH_CONFIG_REGISTRY[model_type](cfg)

    # The message is built purely with f-strings: calling str.format() on text
    # that already contains str(cfg) would try to interpret any braces in the
    # config's repr as replacement fields and raise the wrong exception.
    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
        f" Requested model type: {model_type}"
    )

    return model.build_model(cfg, task)
def register_model(name, dataclass=None):
    """
    Decorator factory that adds a new model type to fairseq.

    For example::

        @register_model('lstm')
        class LSTM(FairseqEncoderDecoderModel):
            (...)

    .. note:: All models must implement the :class:`BaseFairseqModel` interface.
        Typically you will extend :class:`FairseqEncoderDecoderModel` for
        sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
        language modeling tasks.

    Args:
        name (str): the name of the model
        dataclass (optional): config dataclass for the model; it must extend
            ``FairseqDataclass``. When given, it is recorded in
            ``MODEL_DATACLASS_REGISTRY``, a default instance is stored in the
            hydra ``ConfigStore`` under group ``"model"``, and a no-op
            architecture called *name* is registered.
    """

    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model ({})".format(name))
        if not issubclass(cls, BaseFairseqModel):
            raise ValueError(
                "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__)
            )
        MODEL_REGISTRY[name] = cls

        has_dataclass = dataclass is not None
        if has_dataclass and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        cls.__dataclass = dataclass
        if not has_dataclass:
            return cls

        MODEL_DATACLASS_REGISTRY[name] = dataclass

        # Publish a default-valued config node so the model can be selected by name.
        config_store = ConfigStore.instance()
        default_node = dataclass()
        default_node._name = name
        config_store.store(
            name=name, group="model", node=default_node, provider="fairseq"
        )

        # Dataclass-driven models get an architecture named after the model itself.
        @register_model_architecture(name, name)
        def noop(_):
            pass

        return cls

    return register_model_cls
def register_model_architecture(model_name, arch_name):
    """
    Decorator factory that registers a named architecture for a model.

    After registration, the architecture can be selected with the ``--arch``
    command-line argument.

    For example::

        @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
        def lstm_luong_wmt_en_de(cfg):
            cfg.encoder_embed_dim = getattr(cfg, 'encoder_embed_dim', 1000)
            (...)

    The decorated function should take a single argument *cfg* and modify it
    in-place to match the desired architecture.

    Args:
        model_name (str): the name of the Model (Model must already be
            registered)
        arch_name (str): the name of the model architecture (``--arch``)

    Raises:
        ValueError: *model_name* is unknown, *arch_name* is already
            registered, or the decorated object is not callable.
    """

    def register_model_arch_fn(fn):
        model_cls = MODEL_REGISTRY.get(model_name) if model_name in MODEL_REGISTRY else None
        if model_name not in MODEL_REGISTRY:
            raise ValueError(
                "Cannot register model architecture for unknown model type ({})".format(
                    model_name
                )
            )
        if arch_name in ARCH_MODEL_REGISTRY:
            raise ValueError(
                "Cannot register duplicate model architecture ({})".format(arch_name)
            )
        if not callable(fn):
            raise ValueError(
                "Model architecture must be callable ({})".format(arch_name)
            )

        # Record the architecture in every lookup direction.
        ARCH_MODEL_REGISTRY[arch_name] = model_cls
        ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name
        archs_for_model = ARCH_MODEL_INV_REGISTRY.setdefault(model_name, [])
        archs_for_model.append(arch_name)
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return register_model_arch_fn
def import_models(models_dir, namespace):
    """Import every model module or package found directly in *models_dir*.

    Entries whose name starts with ``_`` or ``.`` are ignored, as is anything
    that is neither a ``.py`` file nor a directory. Each remaining entry is
    imported as ``<namespace>.<name>``. If the entry's name is also a
    registered model name, an ``argparse`` parser describing its architectures
    and extra arguments is stored in this module's globals as
    ``<name>_parser`` (used by sphinx).

    Args:
        models_dir (str): directory to scan.
        namespace (str): dotted package name the entries are imported under.
    """
    for entry in os.listdir(models_dir):
        entry_path = os.path.join(models_dir, entry)
        if entry.startswith("_") or entry.startswith("."):
            continue
        is_py_file = entry.endswith(".py")
        if not (is_py_file or os.path.isdir(entry_path)):
            continue

        model_name = entry[: entry.find(".py")] if is_py_file else entry
        importlib.import_module(namespace + "." + model_name)

        # extra `model_parser` for sphinx
        if model_name not in MODEL_REGISTRY:
            continue
        parser = argparse.ArgumentParser(add_help=False)
        arch_group = parser.add_argument_group("Named architectures")
        arch_group.add_argument(
            "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
        )
        extra_group = parser.add_argument_group(
            "Additional command-line arguments"
        )
        MODEL_REGISTRY[model_name].add_args(extra_group)
        globals()[model_name + "_parser"] = parser
# Automatically import any Python files (and sub-packages) in the directory
# containing this file, so that their registration decorators run on import.
models_dir = os.path.dirname(__file__)
import_models(models_dir, "fairseq.models")
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .hub_interface import * # noqa
from .model import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from fairseq.hub_utils import GeneratorHubInterface
from omegaconf import open_dict
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class BARTHubInterface(GeneratorHubInterface):
    """A simple PyTorch Hub interface to BART.

    Usage: https://github.com/pytorch/fairseq/tree/main/examples/bart
    """

    def __init__(self, cfg, task, model):
        super().__init__(cfg, task, [model])
        # Convenience handle to the single wrapped model.
        self.model = self.models[0]

    def encode(
        self, sentence: str, *addl_sentences, no_separator=True
    ) -> torch.LongTensor:
        """
        BPE-encode a sentence (or multiple sentences).

        Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
        Every sentence ends with an end-of-sentence (`</s>`).

        Example (single sentence): `<s> a b c </s>`
        Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`

        The first sentence is truncated to ``min(self.max_positions) - 2``
        BPE tokens so that `<s>` and `</s>` still fit; additional sentences
        are appended without truncation. An extra `</s>` separator is inserted
        before each additional sentence only when *no_separator* is false.

        The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
        requires leading spaces. For example::

            >>> bart.encode('Hello world').tolist()
            [0, 31414, 232, 2]
            >>> bart.encode(' world').tolist()
            [0, 232, 2]
            >>> bart.encode('world').tolist()
            [0, 8331, 2]
        """
        tokens = self.bpe.encode(sentence)
        max_bpe_tokens = min(self.max_positions) - 2
        pieces = tokens.split(" ")
        if len(pieces) > max_bpe_tokens:
            tokens = " ".join(pieces[:max_bpe_tokens])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
        return tokens.long()

    def decode(self, tokens: torch.LongTensor):
        """Convert a 1-D tensor of token ids back to text.

        A leading `<s>` is dropped and the sequence is split wherever two
        `</s>` symbols are adjacent. Returns a single string when there is one
        segment, otherwise a list of strings.
        """
        assert tokens.dim() == 1
        tokens = tokens.cpu().numpy()
        if tokens[0] == self.task.source_dictionary.bos():
            tokens = tokens[1:]  # remove <s>
        eos_mask = tokens == self.task.source_dictionary.eos()
        # True at position i when tokens[i] and tokens[i + 1] are both </s>.
        doc_mask = eos_mask[1:] & eos_mask[:-1]
        sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
        sentences = [
            self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
        ]
        if len(sentences) == 1:
            return sentences[0]
        return sentences

    def _build_sample(self, src_tokens: List[torch.LongTensor]):
        """Collate *src_tokens* into one batch and move it to ``self.device``."""
        # assert torch.is_tensor(src_tokens)
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
        )
        sample = dataset.collater(dataset)
        sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
        return sample

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        *args,
        inference_step_args=None,
        skip_invalid_size_inputs=False,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Generate hypotheses for each input, forcing `<s>` as the first token.

        Args:
            tokenized_sentences: encoded inputs, one tensor per sentence.
            inference_step_args: optional extra arguments for the inference
                step; must not contain ``"prefix_tokens"``. The mapping passed
                in is not modified.
            skip_invalid_size_inputs: forwarded to batching and to the parent
                ``generate``.

        Returns:
            One list of hypotheses per input, ordered by the batch ``id``
            values.

        Raises:
            NotImplementedError: if the caller supplies ``"prefix_tokens"``.
        """
        # Work on a private copy: "prefix_tokens" is written below, and writing
        # it into the caller's dict would make a second call with the same
        # dict fail the check that follows.
        inference_step_args = dict(inference_step_args or {})
        if "prefix_tokens" in inference_step_args:
            raise NotImplementedError("prefix generation not implemented for BART")
        res = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            src_tokens = batch["net_input"]["src_tokens"]
            # One <s> per batch row as the forced decoding prefix.
            inference_step_args["prefix_tokens"] = src_tokens.new_full(
                (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos()
            ).to(device=self.device)
            results = super().generate(
                src_tokens,
                *args,
                inference_step_args=inference_step_args,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                **kwargs
            )
            res.extend(zip(batch["id"].tolist(), results))
        # Restore the ordering given by the sample ids.
        return [hypos for _, hypos in sorted(res, key=lambda x: x[0])]

    def extract_features(
        self, tokens: torch.LongTensor, return_all_hiddens: bool = False
    ) -> torch.Tensor:
        """Run the model in ``features_only`` mode on *tokens*.

        Args:
            tokens: 1-D or 2-D tensor of token ids; a 1-D input is treated as
                a batch of one.
            return_all_hiddens: when true, return the list in
                ``extra["inner_states"]`` with the first two dimensions
                swapped; otherwise return the model's feature output.

        Raises:
            ValueError: if the last dimension of *tokens* exceeds
                ``min(self.model.max_positions())``.
        """
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        if tokens.size(-1) > min(self.model.max_positions()):
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    tokens.size(-1), self.model.max_positions()
                )
            )
        # Tensor.to() returns the moved tensor rather than moving in place, so
        # the result has to be kept.
        tokens = tokens.to(device=self.device)
        prev_output_tokens = tokens.clone()

        # Position 0 of each row gets the token found at index
        # (number of non-pad tokens - 1) of that row; the remaining positions
        # are the input shifted right by one.
        prev_output_tokens[:, 0] = tokens.gather(
            1,
            (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
        ).squeeze()
        prev_output_tokens[:, 1:] = tokens[:, :-1]

        features, extra = self.model(
            src_tokens=tokens,
            src_lengths=None,
            prev_output_tokens=prev_output_tokens,
            features_only=True,
            return_all_hiddens=return_all_hiddens,
        )
        if return_all_hiddens:
            # convert from T x B x C -> B x T x C
            inner_states = extra["inner_states"]
            return [inner_state.transpose(0, 1) for inner_state in inner_states]
        return features  # just the last layer's features

    def register_classification_head(
        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
    ):
        """Forward to the wrapped model's ``register_classification_head``."""
        # NOTE(review): embedding_size is forwarded under that keyword name;
        # confirm the model's register_classification_head consumes it.
        self.model.register_classification_head(
            name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
        )

    def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
        """Apply classification head *head* to the features of *tokens*.

        The features selected at `</s>` positions are regrouped per batch row
        and the last one in each row is fed to the head. Returns raw logits
        when *return_logits* is true, otherwise log-probabilities.
        """
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        features = self.extract_features(tokens.to(device=self.device))
        sentence_representation = features[
            tokens.eq(self.task.source_dictionary.eos()), :
        ].view(features.size(0), -1, features.size(-1))[:, -1, :]

        logits = self.model.classification_heads[head](sentence_representation)
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)

    def fill_mask(
        self,
        masked_inputs: List[str],
        topk: int = 5,
        match_source_len: bool = True,
        **generate_kwargs
    ):
        """Generate the *topk* completions for inputs containing ``<mask>``.

        Args:
            masked_inputs: strings that each contain at least one ``<mask>``.
            topk: number of hypotheses returned per input; the beam size is
                raised to at least this value.
            match_source_len: forwarded to ``generate``.
            **generate_kwargs: additional keyword arguments for ``generate``.

        Returns:
            For each input, a list of ``(decoded text, score)`` pairs.
        """
        masked_token = "<mask>"

        batch_tokens = []
        for masked_input in masked_inputs:
            assert (
                masked_token in masked_input
            ), "please add one {} token for the input".format(masked_token)

            # BPE-encode the text around each mask, then re-insert the masks.
            text_spans = masked_input.split(masked_token)
            text_spans_bpe = (
                (" {0} ".format(masked_token))
                .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
                .strip()
            )
            tokens = self.task.source_dictionary.encode_line(
                "<s> " + text_spans_bpe + " </s>",
                append_eos=False,
                add_if_not_exist=False,
            ).long()
            batch_tokens.append(tokens)

        # ensure beam size is at least as big as topk
        generate_kwargs["beam"] = max(
            topk,
            generate_kwargs.get("beam", -1),
        )
        generate_kwargs["match_source_len"] = match_source_len
        batch_hypos = self.generate(batch_tokens, **generate_kwargs)

        return [
            [(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]]
            for hypos in batch_hypos
        ]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""
import logging
from typing import Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from .hub_interface import BARTHubInterface
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@register_model("bart")
class BARTModel(TransformerModel):
    """Transformer encoder-decoder model with optional classification heads."""

    __jit_unused_properties__ = ["supported_targets"]

    @classmethod
    def hub_models(cls):
        """Return the mapping from pretrained model names to archive URLs."""
        return {
            "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
            "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
            "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
            "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
            "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
        }

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()
        # Cache the end-of-sentence index when the encoder exposes a dictionary.
        if hasattr(self.encoder, "dictionary"):
            self.eos: int = self.encoder.dictionary.eos()

    @staticmethod
    def add_args(parser):
        """Add the parent model's arguments plus the pooler/head options."""
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            "--pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--spectral-norm-classification-head",
            action="store_true",
            help="Apply spectral normalization on the classification head",
        )

    @property
    def supported_targets(self):
        return {"self"}

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        features_only: bool = False,
        classification_head_name: Optional[str] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = True,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Run encoder then decoder; optionally apply a classification head.

        When *classification_head_name* is given, ``features_only`` is forced
        on and the decoder output selected at `</s>` positions of
        *src_tokens* (last one per batch row) is passed through the head of
        that name. Returns the ``(output, extra)`` pair, where ``extra`` comes
        from the decoder.
        """
        use_head = classification_head_name is not None
        if use_head:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        eos: int = self.eos
        if classification_head_name is not None:
            eos_features = x[src_tokens.eq(eos), :]
            sentence_representation = eos_features.view(
                x.size(0), -1, x.size(-1)
            )[:, -1, :]
            for k, head in self.classification_heads.items():
                # for torch script only supports iteration
                if k == classification_head_name:
                    x = head(sentence_representation)
                    break
        return x, extra

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="gpt2",
        sample_break_mode="eos",
        **kwargs,
    ):
        """Load a pretrained model and wrap it in a :class:`BARTHubInterface`."""
        from fairseq import hub_utils

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            sample_break_mode=sample_break_mode,
            **kwargs,
        )
        return BARTHubInterface(loaded["args"], loaded["task"], loaded["models"][0])

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        logger.info("Registering classification head: {0}".format(name))
        if name in self.classification_heads:
            existing = self.classification_heads[name]
            prev_num_classes = existing.out_proj.out_features
            prev_inner_dim = existing.dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        # Any previously registered head of the same name is replaced.
        self.classification_heads[name] = BARTClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim or self.args.encoder_embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
            do_spectral_norm=getattr(
                self.args, "spectral_norm_classification_head", False
            ),
        )

    def upgrade_state_dict_named(self, state_dict, name):
        """Adapt a loaded *state_dict* to this model, in place.

        Steps, in order: reconcile classification heads found in the
        checkpoint with the heads of this model, drop the last embedding row
        when the checkpoint has exactly one more entry than the dictionary and
        the dictionary has no ``<mask>``, extend the embeddings for the
        ``multilingual_denoising`` task when the dictionary grew, and finally
        add this model's head weights that the checkpoint lacks.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + "." if name != "" else ""
        head_prefix = prefix + "classification_heads."
        if hasattr(self, "classification_heads"):
            current_head_names = self.classification_heads.keys()
        else:
            current_head_names = []

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(head_prefix):
                continue

            head_name = k[len(head_prefix) :].split(".")[0]
            num_classes = state_dict[
                head_prefix + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[head_prefix + head_name + ".dense.weight"].size(0)

            head_is_known = head_name in current_head_names
            if getattr(self.args, "load_checkpoint_heads", False):
                if not head_is_known:
                    self.register_classification_head(head_name, num_classes, inner_dim)
                continue

            if not head_is_known:
                logger.warning(
                    "deleting classification head ({}) from checkpoint "
                    "not present in current model: {}".format(head_name, k)
                )
                keys_to_delete.append(k)
                continue

            current_head = self.classification_heads[head_name]
            if (
                num_classes != current_head.out_proj.out_features
                or inner_dim != current_head.dense.out_features
            ):
                logger.warning(
                    "deleting classification head ({}) from checkpoint "
                    "with different dimensions than current model: {}".format(
                        head_name, k
                    )
                )
                keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        # When finetuning on translation task, remove last row of
        # embedding matrix that corresponds to mask_idx token.
        loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
        if (
            loaded_dict_size == len(self.encoder.dictionary) + 1
            and "<mask>" not in self.encoder.dictionary
        ):
            for emb_key in (
                "encoder.embed_tokens.weight",
                "decoder.embed_tokens.weight",
                "encoder.output_projection.weight",
                "decoder.output_projection.weight",
            ):
                if emb_key in state_dict:
                    state_dict[emb_key] = state_dict[emb_key][:-1, :]

        # When continued pretraining on new set of languages for mbart,
        # add extra lang embeddings at the end of embed_tokens.
        # Note: newly added languages are assumed to have been added at the end.
        if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
            self.encoder.dictionary
        ):
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "
                "continued pretraining of MBART on new set of languages."
            )
            encoder_emb = state_dict["encoder.embed_tokens.weight"]
            loaded_mask_token_embedding = encoder_emb[-1, :]

            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            embed_dim = encoder_emb.size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim**-0.5)
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=encoder_emb.dtype,
            )

            # New rows go between the existing vocabulary and the final row,
            # which is re-appended last.
            appended_rows = [
                new_lang_embed_to_add,
                loaded_mask_token_embedding.unsqueeze(0),
            ]
            state_dict["encoder.embed_tokens.weight"] = torch.cat(
                [encoder_emb[: loaded_dict_size - 1, :]] + appended_rows
            )
            state_dict["decoder.embed_tokens.weight"] = torch.cat(
                [state_dict["decoder.embed_tokens.weight"][: loaded_dict_size - 1, :]]
                + appended_rows
            )

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, "classification_heads"):
            for k, v in self.classification_heads.state_dict().items():
                full_key = prefix + "classification_heads." + k
                if full_key not in state_dict:
                    logger.info("Overwriting " + prefix + "classification_heads." + k)
                    state_dict[full_key] = v

    def set_beam_size(self, beam):
        """Set beam size for efficient beamable enc-dec attention."""
        beamable = False
        for layer in self.decoder.layers:
            attn = layer.encoder_attn
            if attn is not None and hasattr(attn, "set_beam_size"):
                attn.set_beam_size(beam)
                beamable = True
        if beamable:
            self.encoder.reorder_encoder_out = self.encoder._reorder_encoder_out
class BARTClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
        do_spectral_norm=False,
    ):
        """Build a dense layer, activation, dropout and output projection.

        Args:
            input_dim: size of the incoming feature vectors.
            inner_dim: output size of the dense layer.
            num_classes: output size of the final projection.
            activation_fn: name resolved through ``utils.get_activation_fn``.
            pooler_dropout: dropout probability used before both linear layers.
            do_spectral_norm: wrap the output projection in spectral norm.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)

        projection = nn.Linear(inner_dim, num_classes)
        self.out_proj = projection
        if do_spectral_norm:
            self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)

    def forward(self, features, **kwargs):
        """Apply dropout -> dense -> activation -> dropout -> projection."""
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(self.activation_fn(hidden))
        return self.out_proj(hidden)
@register_model_architecture("bart", "bart_large")
def bart_large_architecture(args):
    """Fill in ``bart_large`` defaults for every option not already set on *args*.

    Options are processed in a fixed order because some defaults are taken
    from values set earlier (for example the decoder sizes default to the
    encoder sizes).
    """

    def default(option, value):
        # Keep an existing value; otherwise install the default.
        setattr(args, option, getattr(args, option, value))

    default("encoder_embed_path", None)
    default("encoder_embed_dim", 1024)
    default("encoder_ffn_embed_dim", 4 * 1024)
    default("encoder_layers", 12)
    default("encoder_attention_heads", 16)
    default("encoder_normalize_before", False)
    default("encoder_learned_pos", True)
    default("decoder_embed_path", None)
    default("decoder_embed_dim", args.encoder_embed_dim)
    default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    default("decoder_layers", 12)
    default("decoder_attention_heads", 16)
    default("decoder_normalize_before", False)
    default("decoder_learned_pos", True)
    default("attention_dropout", 0.0)
    default("relu_dropout", 0.0)
    default("dropout", 0.1)
    default("max_target_positions", 1024)
    default("max_source_positions", 1024)
    default("adaptive_softmax_cutoff", None)
    default("adaptive_softmax_dropout", 0)
    default("share_decoder_input_output_embed", True)
    default("share_all_embeddings", True)

    default("decoder_output_dim", args.decoder_embed_dim)
    default("decoder_input_dim", args.decoder_embed_dim)

    default("no_scale_embedding", True)
    default("layernorm_embedding", True)

    default("activation_fn", "gelu")
    default("pooler_activation_fn", "tanh")
    default("pooler_dropout", 0.0)
@register_model_architecture("bart", "bart_base")
def bart_base_architecture(args):
    """Apply the ``bart_base`` size defaults, then the ``bart_large`` defaults."""
    base_defaults = (
        ("encoder_embed_dim", 768),
        ("encoder_ffn_embed_dim", 4 * 768),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 12),
        ("decoder_layers", 6),
        ("decoder_attention_heads", 12),
    )
    for option, value in base_defaults:
        setattr(args, option, getattr(args, option, value))
    # Everything not set above falls back to the bart_large values.
    bart_large_architecture(args)
@register_model_architecture("bart", "mbart_large")
def mbart_large_architecture(args):
    """``bart_large`` defaults, except ``no_scale_embedding`` defaults to False."""
    no_scale = getattr(args, "no_scale_embedding", False)
    args.no_scale_embedding = no_scale
    bart_large_architecture(args)
@register_model_architecture("bart", "mbart_base")
def mbart_base_architecture(args):
    """``bart_base`` defaults, except ``no_scale_embedding`` defaults to False."""
    no_scale = getattr(args, "no_scale_embedding", False)
    args.no_scale_embedding = no_scale
    bart_base_architecture(args)
@register_model_architecture("bart", "mbart_base_wmt20")
def mbart_base_wmt20_architecture(args):
    """``mbart_base`` defaults, except ``layernorm_embedding`` defaults to False."""
    use_layernorm = getattr(args, "layernorm_embedding", False)
    args.layernorm_embedding = use_layernorm
    mbart_base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .fairseq_encoder import FairseqEncoder
class CompositeEncoder(FairseqEncoder):
    """
    A wrapper around a dictionary of :class:`FairseqEncoder` objects.

    Forward is run on every wrapped encoder and the results are returned in a
    dictionary under the same keys. The first encoder's dictionary is used to
    initialize the base class.

    Args:
        encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
    """

    def __init__(self, encoders):
        first_encoder = next(iter(encoders.values()))
        super().__init__(first_encoder.dictionary)
        self.encoders = encoders
        # Register each encoder as a submodule under its key.
        for name, encoder in self.encoders.items():
            self.add_module(name, encoder)

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict:
                the outputs from each Encoder
        """
        return {
            name: encoder(src_tokens, src_lengths)
            for name, encoder in self.encoders.items()
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to new_order."""
        # encoder_out is updated in place and also returned.
        for name, encoder in self.encoders.items():
            encoder_out[name] = encoder.reorder_encoder_out(
                encoder_out[name], new_order
            )
        return encoder_out

    def max_positions(self):
        """Return the smallest ``max_positions()`` among the wrapped encoders."""
        return min(encoder.max_positions() for encoder in self.encoders.values())

    def upgrade_state_dict(self, state_dict):
        """Let every wrapped encoder upgrade *state_dict*, then return it."""
        for encoder in self.encoders.values():
            encoder.upgrade_state_dict(state_dict)
        return state_dict
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import signal
import threading
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from fairseq.distributed import (
DistributedTimeoutWrapper,
LegacyDistributedDataParallel,
ModuleProxyWrapper,
TPUDistributedDataParallel,
)
logger = logging.getLogger(__name__)

# SlowMo support is optional: record whether the fairscale classes could be
# imported so the failure can be reported when that backend is requested.
try:
    from fairscale.experimental.nn.data_parallel import (
        SlowMoBaseAlgorithm,
        SlowMoDistributedDataParallel,
    )
except ImportError:
    _SLOWMO_DDP_DISABLED = True
else:
    _SLOWMO_DDP_DISABLED = False
def DistributedFairseqModel(args, model, process_group, device):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        device: device to move model to

    Returns:
        The wrapped model. When ``args.heartbeat_timeout`` is positive the
        result is additionally wrapped in ``DistributedTimeoutWrapper``.

    Raises:
        ImportError: a backend that needs fairscale (or the DDP communication
            hooks) was requested but the import failed.
        ValueError: ``args.ddp_backend`` is not a recognised backend.
    """
    assert isinstance(model, nn.Module)
    if args.tpu:
        wrapped_model = TPUDistributedDataParallel(
            module=model.to(device),
            process_group=process_group,
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend in {"c10d", "pytorch_ddp"}:
        wrapped_model = DistributedDataParallel(
            module=model.to(device),
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
            find_unused_parameters=args.find_unused_parameters,
            gradient_as_bucket_view=args.gradient_as_bucket_view,
        )
        if args.ddp_comm_hook == "fp16":
            logger.info("enable fp16 communication hook in DDP")
            try:
                from torch.distributed.algorithms.ddp_comm_hooks import (
                    DDPCommHookType,
                    register_ddp_comm_hook,
                )
            except ImportError:
                # Only an import failure matches the logged hint; anything
                # else propagates untouched.
                logger.error(
                    "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version"
                )
                raise

            register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model)
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend in {"no_c10d", "legacy_ddp"}:
        wrapped_model = LegacyDistributedDataParallel(
            module=model.to(device),
            buffer_size=2**28,
            process_group=process_group,
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "slowmo":
        if _SLOWMO_DDP_DISABLED:
            raise ImportError(
                "Cannot find SlowMoDistributedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )

        # The values of slowmo_momentum below were obtained by tuning on the
        # En-De 16 dataset by training the transformer_wmt_en_de_large model
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6
        slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()]

        wrapped_model = SlowMoDistributedDataParallel(
            module=model.to(device),
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            slowmo_base_algorithm=slowmo_base_algorithm,
            localsgd_frequency=args.localsgd_frequency,
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "fully_sharded":
        try:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
        except ImportError as err:
            # Chain the original failure so its details stay in the traceback.
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            ) from err
        assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
        wrapped_model = model
        if args.memory_efficient_fp16:
            wrapped_model = wrapped_model.half()
        if not args.cpu_offload:
            wrapped_model = wrapped_model.to(device=device)
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)

    # kill hung distributed jobs after a timeout
    heartbeat_timeout = getattr(args, "heartbeat_timeout", -1)
    if heartbeat_timeout > 0:
        wrapped_model = DistributedTimeoutWrapper(
            wrapped_model, timeout=heartbeat_timeout
        )

    return wrapped_model
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from .ema import EMA
def build_ema(model, cfg, device):
    """Create the trainer-level EMA tracker for *model*.

    Args:
        model: model the EMA is initialized from
        cfg: configuration object handed to :class:`EMA`
        device: device handed to :class:`EMA`

    Returns:
        the newly constructed :class:`EMA` instance
    """
    ema = EMA(model, cfg, device)
    return ema
# Import every public Python module that sits next to this file, so that
# importing the package is enough to load all of them.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    # skip anything that is not a .py file, and private modules ("_" prefix)
    if file.startswith("_") or not file.endswith(".py"):
        continue
    file_name = file[: file.find(".py")]
    importlib.import_module("fairseq.models.ema." + file_name)
#!/usr/bin/env python3
"""
This module has the EMA class used to store a copy of the exponentially decayed
model params.
Typical usage of EMA class involves initializing an object using an existing
model (random or from a seed model) and setting the config like ema_decay,
ema_start_update which determine how the EMA model is updated. After every
update of the model i.e. at the end of the train_step, the EMA should be updated
by passing the new model to the EMA.step function. The EMA model state dict
can be stored in the extra state under the key of "ema" and dumped
into a checkpoint and loaded. The EMA object can be passed to tasks
by setting task.uses_ema property.
EMA is a smoothed/ensemble model which might have better performance
when used for inference or further fine-tuning. EMA class has a
reverse function to load the EMA params into a model and use it
like a regular model.
This implementation is used for trainer-level ema tracking. For EMA tracking
inside the model, please use fairseq/modules/ema_module.py instead.
"""
import copy
import logging
import torch
from fairseq import checkpoint_utils
class EMA(object):
    """Exponential Moving Average of Fairseq Models
    EMA keeps a copy of the exponentially decayed model params.
    The set of params should include both gradient-descent and
    non-gradient descent params, such as batch mean/var and buffers.
    This is a modified implementation of
    the open source code in https://github.com/zhawe01/fairseq-gec.git,
    and internal source code in
    fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py.

    Similar to TF EMA.
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    EMA provides an averaged and smoothed set of model weights, and has been shown to
    improve vision models. EMA class does all necessary functions to update, reload,
    or init EMA methods.

    EMA object is initialized from an arbitrary model. By default, it is stored in
    the same device (unless device specified at initialization) and with the
    same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended.
    This stores the EMA parameters in fp32 only for the EMA update step, and
    is used at the default precision otherwise.
    EMA is usually enabled using EMAConfig with store_ema=True. Some important
    parameters to configure EMA are
    1) ema_decay - The decay of EMA
    2) ema_update_freq - EMA is updated every this many model updates.
    3) ema_start_update - Start EMA update after this many model updates [default 0]

    Key methods:
    1) step - One update of EMA using new model
    2) restore - Update EMA from a state dict
    3) reverse - Load EMA into a model
    4) get_decay, _set_decay - Used to get or set the decay.  Note _set_decay is
    called from step.
    5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params.
    Note this is enabled only when ema_fp32=True
    """

    def __init__(self, model, config, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param config EMAConfig object with configuration like
        ema_decay, ema_update_freq, ema_fp32
        @param device If provided, copy EMA to this device (e.g. gpu).
        Otherwise EMA is in the same device as the model.
        @param skip_keys Optional collection of state-dict keys whose values
        are copied from the new model on every step instead of being decayed.
        """
        self.decay = config.ema_decay
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.config = config
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if self.config.ema_seed_model is not None:
            state = checkpoint_utils.load_ema_from_checkpoint(
                self.config.ema_seed_model
            )
            self.model.load_state_dict(state["model"], strict=True)

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.config.ema_fp32:
            self.build_fp32_params()

        # counts calls to step() since the last actual EMA update
        self.update_freq_counter = 0

    def get_model(self):
        """Return the module that holds the EMA weights."""
        return self.model

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If state dict is passed, the EMA params is copied from
        the provided state dict. Otherwise, it is copied from the
        current EMA model parameters.

        Raises RuntimeError when the config has ema_fp32=False.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            # only floating point tensors are converted; others are kept as is
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                # reuse the existing buffer so its identity stays stable
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def _set_decay(self, decay):
        """Set the decay used by subsequent EMA updates."""
        self.decay = decay

    def get_decay(self):
        """Return the decay currently in use."""
        return self.decay

    def _step_internal(self, new_model, updates=None):
        """One update of the EMA model based on new model weights

        Floating point entries are blended as
        ``ema = decay * ema + (1 - decay) * param``. Entries listed in
        ``skip_keys`` and entries that are not floating point (for example
        integer counters kept as buffers) are copied from *new_model*,
        because in-place scaling of a non-floating tensor by a float decay
        raises in PyTorch.

        Raises ValueError if a tensor of *new_model* and its EMA counterpart
        have different shapes.
        """
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                # key unknown to the EMA so far: start from the new value
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param: "
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            if key in self.skip_keys or not torch.is_floating_point(ema_param):
                # copy verbatim: either requested via skip_keys, or the dtype
                # cannot represent a decayed (fractional) value
                ema_param = param.to(dtype=ema_param.dtype).clone()
            else:
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model, updates=None):
        """
        One update of EMA which is done every self.config.ema_update_freq
        updates of the model.

        @param new_model The model whose current weights are folded into the EMA.
        @param updates The current number of model updates done.
        Decay is set to 0 if model updates < ema_start_update, which means
        the model will be simply copied over to the EMA.
        When model updates >= ema_start_updates, then EMA is updated with
        a decay of self.config.ema_decay.
        """
        if updates is not None:
            self._set_decay(
                0 if updates < self.config.ema_start_update else self.config.ema_decay
            )
        if self.config.ema_update_freq > 1:
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.config.ema_update_freq:
                self._step_internal(new_model, updates)
                self.update_freq_counter = 0
        else:
            self._step_internal(new_model, updates)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.

        Returns the same *model* object after loading.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
import torch.nn as nn
from fairseq import utils
from torch import Tensor
class FairseqDecoder(nn.Module):
    """Common base that fairseq decoders derive from."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary
        self.onnx_trace = False
        self.adaptive_softmax = None

    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Compute features and project them with the output layer.

        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return self.output_layer(features), extra

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Produce decoder features; must be provided by subclasses.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def output_layer(self, features, **kwargs):
        """
        Map features to the default output size, e.g., vocabulary size.
        Must be provided by subclasses.

        Args:
            features (Tensor): features returned by *extract_features*.
        """
        raise NotImplementedError

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Turn a net's output into (log-)probabilities."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # Scriptable subclasses cannot reach the base-class implementation through
    # super() under TorchScript, so the real work lives in a helper with a
    # distinct name that such subclasses can call directly.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Turn a net's output into (log-)probabilities."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            # adaptive softmax yields log-probs and may use the targets
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            if log_probs:
                return out
            return out.exp_()

        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        """Largest input length this decoder accepts."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Hook for adapting old state dicts to newer code; no-op here."""
        return state_dict

    def prepare_for_onnx_export_(self):
        """Flag this decoder as being traced for ONNX export."""
        self.onnx_trace = True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, NamedTuple, Optional
import torch
import torch.nn as nn
from torch import Tensor
# Named record of encoder results; the trailing comments give the tensor
# layouts noted by the original authors.
class EncoderOut(NamedTuple):
    encoder_out: Tensor  # T x B x C
    encoder_padding_mask: Optional[Tensor]  # B x T
    encoder_embedding: Optional[Tensor]  # B x T x C
    encoder_states: Optional[List[Tensor]]  # List[T x B x C]
    src_tokens: Optional[Tensor]  # B x T
    src_lengths: Optional[Tensor]  # B x 1
class FairseqEncoder(nn.Module):
    """Common base that fairseq encoders derive from."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Encode a batch; must be provided by subclasses.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`
        """
        raise NotImplementedError

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """TorchScript-friendly entry point around :func:`forward`.

        When scripting, only ``src_tokens`` and ``src_lengths`` are passed on.
        Encoders that take further arguments may want to override this
        method to stay TorchScript compatible.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        """Call :func:`forward` with every entry except ``prev_output_tokens``."""
        encoder_input = {}
        for key, value in net_input.items():
            if key != "prev_output_tokens":
                encoder_input[key] = value
        return self.forward(**encoder_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Rearrange encoder output to follow `new_order`; must be provided by
        subclasses.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError

    def max_positions(self):
        """Largest input length this encoder accepts."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Hook for adapting old state dicts to newer code; no-op here."""
        return state_dict

    def set_num_updates(self, num_updates):
        """Pass the trainer's update count on to sub-modules that accept it."""

        def _forward_count(module):
            # never re-enter this encoder itself
            if module != self and hasattr(module, "set_num_updates"):
                module.set_num_updates(num_updates)

        self.apply(_forward_count)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, Optional
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import FairseqDecoder
from torch import Tensor
logger = logging.getLogger(__name__)
@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the incremental
    decoder interface allows :func:`forward` functions to take an extra keyword
    argument (*incremental_state*) that can be used to cache state across
    time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of beams.

    To learn more about how incremental decoding works, refer to `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Decode one call's worth of input; must be provided by subclasses.

        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Produce decoder features; must be provided by subclasses.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder incremental state (no-op in this base class).

        Invoked when the order of the input differs from the previous time
        step. Beam search is the typical case: the selection of beams changes
        the input order between time steps.
        """
        pass

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Main entry point for reordering the incremental state.

        Due to limitations in TorchScript, we call this function in
        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
        calling :func:`reorder_incremental_state` directly.

        Every module returned by ``self.modules()`` that defines
        ``reorder_incremental_state`` is asked to reorder; a non-None result
        replaces the state passed to the modules that follow.
        """
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                reordered = module.reorder_incremental_state(
                    incremental_state, new_order
                )
                if reordered is not None:
                    incremental_state = reordered

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, "_beam_size", -1) == beam_size:
            # nothing changed since the last call
            return

        visited = set()

        def _propagate(module):
            if module == self or module in visited:
                return
            if hasattr(module, "set_beam_size"):
                visited.add(module)
                module.set_beam_size(beam_size)

        self.apply(_propagate)
        self._beam_size = beam_size
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Base classes for various fairseq models.
"""
import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor
logger = logging.getLogger(__name__)
def check_type(module, expected_type):
    """Assert that *module* is an instance of *expected_type*.

    If *module* has an ``unwrapped_module`` attribute, that inner object is
    checked instead of *module* itself.

    Raises:
        AssertionError: when the checked object is not an *expected_type*
    """
    if hasattr(module, "unwrapped_module"):
        target = module.unwrapped_module
    else:
        target = module
    assert isinstance(target, expected_type), f"{type(target)} != {expected_type}"
class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        # set once make_generation_fast_ has run, so it is applied only once
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser.

        Arguments are generated from the class attribute named
        ``__dataclass`` when it is present; otherwise nothing is added.
        """
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults so that settings defaults from various architectures still works
            gen_parser_from_dataclass(parser, dc(), delete_default=True)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel

        Delegates to ``self.decoder`` when the model has one; a bare tensor
        output is normalized directly over its last dimension. Anything else
        raises NotImplementedError.
        """
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict (dict): state to load; upgraded in place before loading
            strict (bool): forwarded to :meth:`nn.Module.load_state_dict`
            model_cfg (DictConfig, optional): model config handed to
                ``prune_state_dict``
            args (Namespace, optional): deprecated; converted to a config and
                used only when *model_cfg* is None
        """

        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias; use Logger.warning
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # imported here rather than at module level, as in the original code
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                name = prefix + n
                # prefer the named variant when a child offers both hooks
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)

        do_upgrade(self, name)

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        for m in self.modules():
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference.

        Reads settings from ``cfg.generation`` and passes them as keyword
        arguments to :func:`make_generation_fast_`.
        """
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.

        After this call the model is in eval mode and ``train(True)`` raises
        RuntimeError.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m != self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        self.eval()
        self.train = train

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace."""
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module != self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

    @classmethod
    def hub_models(cls):
        """Return the archive map used by :func:`from_pretrained`; empty here."""
        return {}
class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for models made of one encoder and one decoder.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        # both parts must derive from the matching fairseq base class
        for part, expected in (
            (self.encoder, FairseqEncoder),
            (self.decoder, FairseqDecoder),
        ):
            check_type(part, expected)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        The source batch goes through the encoder first; its output, together
        with the previous decoder outputs (i.e., teacher forcing), is then fed
        to the decoder::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Extra keyword arguments are handed to both the encoder and the decoder.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Call the decoder alone, forwarding all keyword arguments."""
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Same flow as *forward*, but stops at the decoder's features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )

    def output_layer(self, features, **kwargs):
        """Apply the decoder's output projection (typically to vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum lengths supported, as an (encoder, decoder) pair."""
        encoder_limit = self.encoder.max_positions()
        decoder_limit = self.decoder.max_positions()
        return (encoder_limit, decoder_limit)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
class FairseqModel(FairseqEncoderDecoderModel):
    """Deprecated alias of :class:`FairseqEncoderDecoderModel`.

    Behaves like its parent, but emits a deprecation warning on construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        message = (
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead"
        )
        utils.deprecation_warning(message, stacklevel=4)
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        """
        Args:
            encoders (dict): maps each key to a :class:`FairseqEncoder`
            decoders (dict): maps each key to a :class:`FairseqDecoder`;
                must have exactly the same keys as *encoders*
        """
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        # one encoder-decoder model per key
        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings

        Raises:
            ValueError: if the dictionaries of *langs* are not all equal
        """
        # the first language's dictionary is the reference for all others
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """Not implemented in this base class."""
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model.

        Returns a dict mapping each key to an (encoder, decoder) pair of limits.
        """
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder.

        This is the smallest limit among all decoders.
        """
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        """Encoder of the model stored under the first key."""
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        """Decoder of the model stored under the first key."""
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Call the first key's decoder, forwarding all keyword arguments."""
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict (dict): state to load; upgraded in place before loading
            strict (bool): forwarded to the parent's ``load_state_dict``
            model_cfg (optional): model config handed to ``prune_state_dict``
            args (Namespace, optional): deprecated; converted to a config and
                used only when *model_cfg* is None
        """

        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias; use Logger.warning
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # imported here rather than at module level, as in the original code
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for models that consist of a decoder only.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        # the decoder must derive from FairseqDecoder
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        The tokens are passed straight to the decoder, together with any
        extra keyword arguments, to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        decoder_out = self.decoder(src_tokens, **kwargs)
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Call the decoder, forwarding all keyword arguments."""
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Same flow as *forward*, but stops at the decoder's features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        features = self.decoder.extract_features(src_tokens, **kwargs)
        return features

    def output_layer(self, features, **kwargs):
        """Apply the decoder's output projection (typically to vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model (the decoder's limit)."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        """Set of target kinds this model supports; only ``"future"``."""
        return {"future"}
class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # type-check the stored encoder against the FairseqEncoder base class
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for an encoder-only model.

        Args:
            src_tokens: input tokens, passed to the encoder first
            src_lengths: source lengths, passed to the encoder second
            **kwargs: forwarded unchanged to the encoder

        Returns:
            the encoder's return value, unmodified
        """
        encoder_out = self.encoder(src_tokens, src_lengths, **kwargs)
        return encoder_out

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output.

        Only a tensor-valued ``net_output["encoder_out"]`` is supported; any
        other value raises ``NotImplementedError``. *sample* is unused.
        """
        encoder_out = net_output["encoder_out"]
        if not torch.is_tensor(encoder_out):
            raise NotImplementedError

        # normalise in float precision over the last dimension
        logits = encoder_out.float()
        normalize = F.log_softmax if log_probs else F.softmax
        return normalize(logits, dim=-1)

    def max_positions(self):
        """Maximum length supported by the model (the encoder's limit)."""
        limit = self.encoder.max_positions()
        return limit
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
AdaptiveSoftmax,
BeamableMM,
FairseqDropout,
GradMultiply,
LearnedPositionalEmbedding,
LinearizedConvolution,
)
@register_model("fconv")
class FConvModel(FairseqEncoderDecoderModel):
    """
    A fully convolutional model, i.e. a convolutional encoder and a
    convolutional decoder, as described in `"Convolutional Sequence to Sequence
    Learning" (Gehring et al., 2017) <https://arxiv.org/abs/1705.03122>`_.

    Args:
        encoder (FConvEncoder): the encoder
        decoder (FConvDecoder): the decoder

    The Convolutional model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.fconv_parser
        :prog:
    """

    @classmethod
    def hub_models(cls):
        """Map each published model name to its archive URL and tokenizer/BPE settings."""
        archives = {
            "conv.wmt14.en-fr": "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2",
            "conv.wmt14.en-de": "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2",
            "conv.wmt17.en-de": "https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2",
        }
        # every entry uses the same moses + subword_nmt preprocessing
        return {
            name: {"path": url, "tokenizer": "moses", "bpe": "subword_nmt"}
            for name, url in archives.items()
        }

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
        # tell the encoder how many decoder layers actually carry attention
        attention_count = sum(layer is not None for layer in decoder.attention)
        self.encoder.num_attention_layers = attention_count

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        options = (
            ('--dropout', dict(type=float, metavar='D',
                               help='dropout probability')),
            ('--encoder-embed-dim', dict(type=int, metavar='N',
                                         help='encoder embedding dimension')),
            ('--encoder-embed-path', dict(type=str, metavar='STR',
                                          help='path to pre-trained encoder embedding')),
            ('--encoder-layers', dict(type=str, metavar='EXPR',
                                      help='encoder layers [(dim, kernel_size), ...]')),
            ('--decoder-embed-dim', dict(type=int, metavar='N',
                                         help='decoder embedding dimension')),
            ('--decoder-embed-path', dict(type=str, metavar='STR',
                                          help='path to pre-trained decoder embedding')),
            ('--decoder-layers', dict(type=str, metavar='EXPR',
                                      help='decoder layers [(dim, kernel_size), ...]')),
            ('--decoder-out-embed-dim', dict(type=int, metavar='N',
                                             help='decoder output embedding dimension')),
            ('--decoder-attention', dict(type=str, metavar='EXPR',
                                         help='decoder attention [True, ...]')),
            ('--share-input-output-embed', dict(action='store_true',
                                                help='share input and output embeddings (requires'
                                                     ' --decoder-out-embed-dim and --decoder-embed-dim'
                                                     ' to be equal)')),
        )
        # fmt: on
        for flag, spec in options:
            parser.add_argument(flag, **spec)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        source_embeddings = None
        if args.encoder_embed_path:
            source_embeddings = utils.parse_embedding(args.encoder_embed_path)
            utils.print_embed_overlap(source_embeddings, task.source_dictionary)

        target_embeddings = None
        if args.decoder_embed_path:
            target_embeddings = utils.parse_embedding(args.decoder_embed_path)
            utils.print_embed_overlap(target_embeddings, task.target_dictionary)

        # NOTE(review): the layer/attention specs are Python expressions given
        # as strings and are evaluated with eval(); only pass trusted values.
        encoder_options = dict(
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            embed_dict=source_embeddings,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
            max_positions=args.max_source_positions,
        )
        encoder = FConvEncoder(**encoder_options)

        decoder_options = dict(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            embed_dict=target_embeddings,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            share_embed=args.share_input_output_embed,
        )
        decoder = FConvDecoder(**decoder_options)

        return FConvModel(encoder, decoder)
class FConvEncoder(FairseqEncoder):
    """
    Convolutional encoder consisting of `len(convolutions)` layers.

    Args:
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_dim (int, optional): embedding dimension
        embed_dict (optional): pre-trained embeddings; when truthy it is
            handed to ``utils.load_embedding`` together with the dictionary
            (``FConvModel.build_model`` passes the result of
            ``utils.parse_embedding``)
        max_positions (int, optional): maximum supported input sequence length
        convolutions (list, optional): the convolutional layer structure. Each
            list item `i` corresponds to convolutional layer `i`. Layers are
            given as ``(out_channels, kernel_width, [residual])``. Residual
            connections are added between layers when ``residual=1`` (which is
            the default behavior).
        dropout (float, optional): dropout to be applied before each conv layer
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        embed_dict=None,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        dropout=0.1,
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # filled in later by the owning model (see FConvModel.__init__)
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        if embed_dict:
            self.embed_tokens = utils.load_embedding(
                embed_dict, self.dictionary, self.embed_tokens
            )

        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
        )

        convolutions = extend_conv_spec(convolutions)
        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.residuals = []

        # channel count feeding each layer; indexed from the end by `residual`
        layer_in_channels = [in_channels]
        for out_channels, kernel_size, residual in convolutions:
            if residual == 0:
                residual_dim = out_channels
            else:
                residual_dim = layer_in_channels[-residual]
            # a projection is only needed when the residual source has a
            # different width than this layer's output
            self.projections.append(
                Linear(residual_dim, out_channels)
                if residual_dim != out_channels
                else None
            )
            # odd kernels are padded inside the conv; even kernels are padded
            # explicitly in forward()
            if kernel_size % 2 == 1:
                padding = kernel_size // 2
            else:
                padding = 0
            self.convolutions.append(
                ConvTBC(
                    in_channels,
                    out_channels * 2,  # doubled because forward() applies a GLU
                    kernel_size,
                    dropout=dropout,
                    padding=padding,
                )
            )
            self.residuals.append(residual)
            in_channels = out_channels
            layer_in_channels.append(out_channels)
        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`; not used by this implementation

        Returns:
            dict:
                - **encoder_out** (tuple): a tuple with two elements, where the
                  first element is the last encoder layer's output and the
                  second element is the same quantity summed with the input
                  embedding (used for attention). The shape of both tensors is
                  `(batch, src_len, embed_dim)`.
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when the batch contains no padding
        """
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = self.dropout_module(x)
        input_embedding = x

        # project to size of convolution
        x = self.fc1(x)

        # used to mask padding in input
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        residuals = [x]
        # temporal convolutions
        for proj, conv, res_layer in zip(
            self.projections, self.convolutions, self.residuals
        ):
            if res_layer > 0:
                residual = residuals[-res_layer]
                residual = residual if proj is None else proj(residual)
            else:
                residual = None

            if encoder_padding_mask is not None:
                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

            x = self.dropout_module(x)
            if conv.kernel_size[0] % 2 == 1:
                # padding is implicit in the conv
                x = conv(x)
            else:
                # even kernel: pad the time dimension by hand
                padding_l = (conv.kernel_size[0] - 1) // 2
                padding_r = conv.kernel_size[0] // 2
                x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
                x = conv(x)
            x = F.glu(x, dim=2)

            if residual is not None:
                x = (x + residual) * math.sqrt(0.5)
            residuals.append(x)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

        # scale gradients (this only affects backward, not forward).
        # Skipped when no attention-layer count is available (None) or the
        # count is zero: the scale factor would otherwise raise TypeError or
        # ZeroDivisionError.
        if self.num_attention_layers:
            x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding) * math.sqrt(0.5)

        return {
            "encoder_out": (x, y),
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the first dimension of the cached outputs by *new_order*.

        The *encoder_out* dict is updated in place and also returned.
        """
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = (
                encoder_out["encoder_out"][0].index_select(0, new_order),
                encoder_out["encoder_out"][1].index_select(0, new_order),
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions
class AttentionLayer(nn.Module):
    """Attention over the encoder outputs, used between decoder convolutions.

    Args:
        conv_channels: width of the convolution output fed into the layer
        embed_dim: embedding width the attention is computed in
        bmm: optional batched-matmul callable; defaults to ``torch.bmm``
    """

    def __init__(self, conv_channels, embed_dim, bmm=None):
        super().__init__()
        # projects from output of convolution to embedding dimension
        self.in_projection = Linear(conv_channels, embed_dim)
        # projects from embedding dimension to convolution size
        self.out_projection = Linear(embed_dim, conv_channels)

        if bmm is None:
            bmm = torch.bmm
        self.bmm = bmm

    def forward(self, x, target_embedding, encoder_out, encoder_padding_mask):
        """Attend over *encoder_out* and return ``(output, attention_scores)``.

        ``encoder_out[0]`` is multiplied with the query to score positions and
        ``encoder_out[1]`` is averaged with the resulting weights.
        """
        half_scale = math.sqrt(0.5)
        shortcut = x
        keys = encoder_out[0]
        values = encoder_out[1]

        # attention
        query = (self.in_projection(x) + target_embedding) * half_scale
        scores = self.bmm(query, keys)

        # don't attend over padding
        if encoder_padding_mask is not None:
            # FP16 support: cast to float and back
            masked = scores.float().masked_fill(
                encoder_padding_mask.unsqueeze(1), float("-inf")
            )
            scores = masked.type_as(scores)

        # softmax over last dim
        shape = scores.size()
        flat = scores.view(shape[0] * shape[1], shape[2])
        attn_scores = F.softmax(flat, dim=1).view(shape)

        context = self.bmm(attn_scores, values)

        # scale attention output (respecting potentially different lengths)
        src_len = values.size(1)
        if encoder_padding_mask is None:
            context = context * (src_len * math.sqrt(1.0 / src_len))
        else:
            # exclude padding
            valid = src_len - encoder_padding_mask.type_as(context).sum(
                dim=1, keepdim=True
            )
            valid = valid.unsqueeze(-1)
            context = context * (valid * valid.rsqrt())

        # project back
        output = (self.out_projection(context) + shortcut) * half_scale
        return output, attn_scores

    def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
        """Replace torch.bmm with BeamableMM."""
        if beamable_mm_beam_size is None:
            return
        del self.bmm
        self.add_module("bmm", BeamableMM(beamable_mm_beam_size))
class FConvDecoder(FairseqIncrementalDecoder):
    """Convolutional decoder

    Stacks ``len(convolutions)`` gated (GLU) linearized convolutions, each
    optionally followed by an :class:`AttentionLayer` over the encoder output,
    and projects the result to the vocabulary through ``fc2``/``fc3`` unless an
    adaptive softmax is configured.

    Args:
        dictionary: decoding dictionary; supplies vocabulary size and pad index
        embed_dim: token embedding dimension
        embed_dict: pre-trained embeddings; when truthy, handed to
            ``utils.load_embedding``
        out_embed_dim: width of the ``fc2`` output that feeds ``fc3``
        max_positions: size of the positional embedding table
        convolutions: layer spec, expanded by ``extend_conv_spec``
        attention: a bool (applied to every layer) or a list of bools, one per
            convolutional layer
        dropout: dropout probability
        share_embed: tie ``fc3.weight`` to the token embedding weight
        positional_embeddings: whether to create positional embeddings
        adaptive_softmax_cutoff: when not ``None``, build an ``AdaptiveSoftmax``
            instead of the ``fc2``/``fc3`` projections
        adaptive_softmax_dropout: dropout passed to ``AdaptiveSoftmax``
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        embed_dict=None,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        attention=True,
        dropout=0.1,
        share_embed=False,
        positional_embeddings=True,
        adaptive_softmax_cutoff=None,
        adaptive_softmax_dropout=0.0,
    ):
        super().__init__(dictionary)
        # stored with the weights; read back by upgrade_state_dict
        self.register_buffer("version", torch.Tensor([2]))
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.need_attn = True

        convolutions = extend_conv_spec(convolutions)
        in_channels = convolutions[0][0]
        if isinstance(attention, bool):
            # expand True into [True, True, ...] and do the same with False
            attention = [attention] * len(convolutions)
        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError(
                "Attention is expected to be a list of booleans of "
                "length equal to the number of layers."
            )

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        if embed_dict:
            self.embed_tokens = utils.load_embedding(
                embed_dict, self.dictionary, self.embed_tokens
            )

        self.embed_positions = (
            PositionalEmbedding(
                max_positions,
                embed_dim,
                padding_idx,
            )
            if positional_embeddings
            else None
        )

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.residuals = []

        # channel count feeding each layer; indexed from the end by `residual`
        layer_in_channels = [in_channels]
        for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
            if residual == 0:
                residual_dim = out_channels
            else:
                residual_dim = layer_in_channels[-residual]
            # project the residual only when its width differs from this layer's output
            self.projections.append(
                Linear(residual_dim, out_channels)
                if residual_dim != out_channels
                else None
            )
            self.convolutions.append(
                LinearizedConv1d(
                    in_channels,
                    out_channels * 2,  # doubled because forward() applies a GLU
                    kernel_size,
                    padding=(kernel_size - 1),
                    dropout=dropout,
                )
            )
            self.attention.append(
                AttentionLayer(out_channels, embed_dim) if attention[i] else None
            )
            self.residuals.append(residual)
            in_channels = out_channels
            layer_in_channels.append(out_channels)

        self.adaptive_softmax = None
        self.fc2 = self.fc3 = None

        if adaptive_softmax_cutoff is not None:
            # adaptive softmax replaces fc2/fc3, so there is nothing to share
            assert not share_embed
            self.adaptive_softmax = AdaptiveSoftmax(
                num_embeddings,
                in_channels,
                adaptive_softmax_cutoff,
                dropout=adaptive_softmax_dropout,
            )
        else:
            self.fc2 = Linear(in_channels, out_embed_dim)
            if share_embed:
                assert out_embed_dim == embed_dim, (
                    "Shared embed weights implies same dimensions "
                    " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
                )
                self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
                self.fc3.weight = self.embed_tokens.weight
            else:
                self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """Run the decoder and return ``(x, avg_attn_scores)``.

        Args:
            prev_output_tokens: previous output tokens; when
                *incremental_state* is given only the last column is embedded
            encoder_out: dict with ``"encoder_out"`` and
                ``"encoder_padding_mask"`` entries, or ``None``
            incremental_state: cache used for step-by-step decoding, or
                ``None`` for a full-sequence pass

        Returns:
            tuple: the output after ``fc2``/``fc3`` (or the raw convolution
            output when those are ``None``) and the attention scores summed
            over layers, which stay ``None`` unless the module is in eval mode
            with ``need_attn`` set and at least one attention layer ran.

        NOTE(review): ``encoder_a``, ``encoder_b`` and ``encoder_padding_mask``
        are only bound when *encoder_out* is not ``None``; calling with
        ``encoder_out=None`` on a decoder that has attention layers would hit
        an unbound local — confirm callers never do this.
        """
        if encoder_out is not None:
            encoder_padding_mask = encoder_out["encoder_padding_mask"]
            encoder_out = encoder_out["encoder_out"]

            # split and transpose encoder outputs
            encoder_a, encoder_b = self._split_encoder_out(
                encoder_out, incremental_state
            )

        if self.embed_positions is not None:
            pos_embed = self.embed_positions(prev_output_tokens, incremental_state)
        else:
            pos_embed = 0

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self._embed_tokens(prev_output_tokens, incremental_state)

        # embed tokens and combine with positional embeddings
        x += pos_embed
        x = self.dropout_module(x)
        target_embedding = x

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = self._transpose_if_training(x, incremental_state)

        # temporal convolutions
        avg_attn_scores = None
        # counts every slot in self.attention, including the None placeholders
        num_attn_layers = len(self.attention)
        residuals = [x]
        for proj, conv, attention, res_layer in zip(
            self.projections, self.convolutions, self.attention, self.residuals
        ):
            if res_layer > 0:
                residual = residuals[-res_layer]
                residual = residual if proj is None else proj(residual)
            else:
                residual = None

            x = self.dropout_module(x)
            x = conv(x, incremental_state)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                x = self._transpose_if_training(x, incremental_state)

                x, attn_scores = attention(
                    x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask
                )

                if not self.training and self.need_attn:
                    attn_scores = attn_scores / num_attn_layers
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        avg_attn_scores.add_(attn_scores)

                x = self._transpose_if_training(x, incremental_state)

            # residual
            if residual is not None:
                x = (x + residual) * math.sqrt(0.5)
            residuals.append(x)

        # T x B x C -> B x T x C
        x = self._transpose_if_training(x, incremental_state)

        # project back to size of vocabulary if not using adaptive softmax
        if self.fc2 is not None and self.fc3 is not None:
            x = self.fc2(x)
            x = self.dropout_module(x)
            x = self.fc3(x)

        return x, avg_attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the cached state, including the cached encoder outputs, by *new_order*."""
        super().reorder_incremental_state(incremental_state, new_order)
        encoder_out = utils.get_incremental_state(
            self, incremental_state, "encoder_out"
        )
        if encoder_out is not None:
            encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
            utils.set_incremental_state(
                self, incremental_state, "encoder_out", encoder_out
            )

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return (
            self.embed_positions.max_positions
            if self.embed_positions is not None
            else float("inf")
        )

    def upgrade_state_dict(self, state_dict):
        """Re-apply weight norm with ``dim=0`` when loading a pre-version-2 checkpoint.

        A missing ``"decoder.version"`` entry is treated as version 1. After
        reconfiguring the convolutions the entry is written as version 1.
        Returns *state_dict*.
        """
        if utils.item(state_dict.get("decoder.version", torch.Tensor([1]))[0]) < 2:
            # old models use incorrect weight norm dimension
            for i, conv in enumerate(self.convolutions):
                # reconfigure weight norm
                nn.utils.remove_weight_norm(conv)
                self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
            state_dict["decoder.version"] = torch.Tensor([1])
        return state_dict

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Record whether attention scores should be accumulated in forward()."""
        self.need_attn = need_attn

    def _embed_tokens(self, tokens, incremental_state):
        """Embed *tokens*, restricted to the last column when decoding incrementally."""
        if incremental_state is not None:
            # keep only the last token for incremental forward pass
            tokens = tokens[:, -1:]
        return self.embed_tokens(tokens)

    def _split_encoder_out(self, encoder_out, incremental_state):
        """Split and transpose encoder outputs.

        This is cached when doing incremental inference.
        """
        cached_result = utils.get_incremental_state(
            self, incremental_state, "encoder_out"
        )
        if cached_result is not None:
            return cached_result

        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(1, 2).contiguous()
        result = (encoder_a, encoder_b)

        if incremental_state is not None:
            utils.set_incremental_state(self, incremental_state, "encoder_out", result)
        return result

    def _transpose_if_training(self, x, incremental_state):
        """Swap the first two dimensions of *x* unless decoding incrementally."""
        if incremental_state is None:
            x = x.transpose(0, 1)
        return x
def extend_conv_spec(convolutions):
    """
    Extend a convolutional spec so that every layer has three fields.

    Each item is ``(dim size, kernel size)`` or
    ``(dim size, kernel size, residual)``, where ``residual`` is how many
    layers behind to look for the residual connection. Two-field items get
    the default ``residual=1``. Items may be any sequence (tuple or list);
    each one is returned as a tuple.

    Args:
        convolutions: iterable of 2- or 3-element sequences.

    Returns:
        tuple: one 3-tuple per input item, in order.

    Raises:
        ValueError: if an item does not have 2 or 3 elements.
    """
    extended = []
    for spec in convolutions:
        # normalise first so list-typed specs can be concatenated below
        layer = tuple(spec)
        if len(layer) == 3:
            extended.append(layer)
        elif len(layer) == 2:
            extended.append(layer + (1,))
        else:
            raise ValueError(
                "invalid number of parameters in convolution spec "
                + str(spec)
                + ". expected 2 or 3"
            )
    return tuple(extended)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an ``nn.Embedding`` with N(0, 0.1) weights and a zeroed padding row."""
    embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    weight = embedding.weight
    nn.init.normal_(weight, 0, 0.1)
    # the padding row must stay all-zero after the random initialisation
    nn.init.zeros_(weight[padding_idx])
    return embedding
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
    """Build a ``LearnedPositionalEmbedding`` with N(0, 0.1) weights and a zeroed padding row."""
    positions = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
    weight = positions.weight
    nn.init.normal_(weight, 0, 0.1)
    # the padding row must stay all-zero after the random initialisation
    nn.init.zeros_(weight[padding_idx])
    return positions
def Linear(in_features, out_features, dropout=0.0):
    """Weight-normalized Linear layer (input: N x T x C)"""
    # the initial weight spread shrinks as dropout grows
    weight_std = math.sqrt((1 - dropout) / in_features)
    layer = nn.Linear(in_features, out_features)
    nn.init.normal_(layer.weight, mean=0, std=weight_std)
    nn.init.zeros_(layer.bias)
    return nn.utils.weight_norm(layer)
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    conv = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    # the initial weight spread depends on dropout and the kernel's fan-in
    fan_in = conv.kernel_size[0] * in_channels
    weight_std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    nn.init.normal_(conv.weight, mean=0, std=weight_std)
    nn.init.zeros_(conv.bias)
    return nn.utils.weight_norm(conv, dim=2)
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Weight-normalized Conv1d layer"""
    # imported here under an alias because this factory shares the class's name
    from fairseq.modules import ConvTBC as _ConvTBCModule

    conv = _ConvTBCModule(in_channels, out_channels, kernel_size, **kwargs)
    fan_in = conv.kernel_size[0] * in_channels
    weight_std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    nn.init.normal_(conv.weight, mean=0, std=weight_std)
    nn.init.zeros_(conv.bias)
    return nn.utils.weight_norm(conv, dim=2)
@register_model_architecture("fconv", "fconv")
def base_architecture(args):
    """Fill in every ``fconv`` hyper-parameter missing from *args* with its default."""
    defaults = (
        ("dropout", 0.1),
        ("encoder_embed_dim", 512),
        ("encoder_embed_path", None),
        ("encoder_layers", "[(512, 3)] * 20"),
        ("decoder_embed_dim", 512),
        ("decoder_embed_path", None),
        ("decoder_layers", "[(512, 3)] * 20"),
        ("decoder_out_embed_dim", 256),
        ("decoder_attention", "True"),
        ("share_input_output_embed", False),
    )
    # values already present on args win over the defaults
    for name, fallback in defaults:
        setattr(args, name, getattr(args, name, fallback))
@register_model_architecture("fconv", "fconv_iwslt_de_en")
def fconv_iwslt_de_en(args):
    """Apply the ``fconv_iwslt_de_en`` defaults, then the base ``fconv`` defaults."""
    preset = (
        ("encoder_embed_dim", 256),
        ("encoder_layers", "[(256, 3)] * 4"),
        ("decoder_embed_dim", 256),
        ("decoder_layers", "[(256, 3)] * 3"),
        ("decoder_out_embed_dim", 256),
    )
    # values already present on args win over the preset
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
@register_model_architecture("fconv", "fconv_wmt_en_ro")
def fconv_wmt_en_ro(args):
    """Apply the ``fconv_wmt_en_ro`` default, then the base ``fconv`` defaults."""
    # a value already present on args wins over the preset
    out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
    args.decoder_out_embed_dim = out_embed_dim
    base_architecture(args)
@register_model_architecture("fconv", "fconv_wmt_en_de")
def fconv_wmt_en_de(args):
    """Apply the ``fconv_wmt_en_de`` defaults, then the base ``fconv`` defaults."""
    convs = "".join(
        (
            "[(512, 3)] * 9",  # first 9 layers have 512 units
            " + [(1024, 3)] * 4",  # next 4 layers have 1024 units
            " + [(2048, 1)] * 2",  # final 2 layers use 1x1 convolutions
        )
    )
    preset = (
        ("encoder_embed_dim", 768),
        ("encoder_layers", convs),
        ("decoder_embed_dim", 768),
        ("decoder_layers", convs),
        ("decoder_out_embed_dim", 512),
    )
    # values already present on args win over the preset
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
@register_model_architecture("fconv", "fconv_wmt_en_fr")
def fconv_wmt_en_fr(args):
    """Apply the ``fconv_wmt_en_fr`` defaults, then the base ``fconv`` defaults."""
    convs = "".join(
        (
            "[(512, 3)] * 6",  # first 6 layers have 512 units
            " + [(768, 3)] * 4",  # next 4 layers have 768 units
            " + [(1024, 3)] * 3",  # next 3 layers have 1024 units
            " + [(2048, 1)] * 1",  # next 1 layer uses 1x1 convolutions
            " + [(4096, 1)] * 1",  # final 1 layer uses 1x1 convolutions
        )
    )
    preset = (
        ("encoder_embed_dim", 768),
        ("encoder_layers", convs),
        ("decoder_embed_dim", 768),
        ("decoder_layers", convs),
        ("decoder_out_embed_dim", 512),
    )
    # values already present on args win over the preset
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import utils
from fairseq.models import (
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.fconv import FConvDecoder
from fairseq.utils import safe_hasattr
@register_model("fconv_lm")
class FConvLanguageModel(FairseqLanguageModel):
    """Language model built around a single :class:`FConvDecoder`."""

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        options = (
            (
                "--dropout",
                dict(type=float, metavar="D", help="dropout probability"),
            ),
            (
                "--decoder-embed-dim",
                dict(type=int, metavar="N", help="decoder embedding dimension"),
            ),
            (
                "--decoder-layers",
                dict(
                    type=str,
                    metavar="EXPR",
                    help="decoder layers [(dim, kernel_size), ...]",
                ),
            ),
            (
                "--decoder-out-embed-dim",
                dict(
                    type=int,
                    metavar="N",
                    help="decoder output embedding dimension",
                ),
            ),
            (
                "--adaptive-softmax-cutoff",
                dict(
                    metavar="EXPR",
                    help="comma separated list of adaptive softmax cutoff points. "
                    "Must be used with adaptive_loss criterion",
                ),
            ),
            (
                "--adaptive-softmax-dropout",
                dict(
                    type=float,
                    metavar="D",
                    help="sets adaptive softmax dropout for the tail projections",
                ),
            ),
            (
                "--decoder-attention",
                dict(type=str, metavar="EXPR", help="decoder attention [True, ...]"),
            ),
        )
        for flag, spec in options:
            parser.add_argument(flag, **spec)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_lm_architecture(args)

        # fall back to max_target_positions when tokens_per_sample is absent
        if safe_hasattr(args, "max_target_positions"):
            if not safe_hasattr(args, "tokens_per_sample"):
                args.tokens_per_sample = args.max_target_positions

        # NOTE(review): decoder_layers / decoder_attention are Python
        # expressions given as strings and are evaluated with eval(); only
        # pass trusted values.
        decoder_options = dict(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.tokens_per_sample,
            share_embed=False,
            positional_embeddings=False,
            # the cutoffs are only parsed for the adaptive_loss criterion
            adaptive_softmax_cutoff=(
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion == "adaptive_loss"
                else None
            ),
            adaptive_softmax_dropout=args.adaptive_softmax_dropout,
        )
        decoder = FConvDecoder(**decoder_options)
        return FConvLanguageModel(decoder)
@register_model_architecture("fconv_lm", "fconv_lm")
def base_lm_architecture(args):
    """Fill in every ``fconv_lm`` hyper-parameter missing from *args* with its default."""
    defaults = (
        ("dropout", 0.1),
        ("decoder_embed_dim", 128),
        ("decoder_layers", "[(1268, 4)] * 13"),
        ("decoder_attention", "False"),
        ("adaptive_softmax_cutoff", None),
        ("adaptive_softmax_dropout", 0),
    )
    # values already present on args win over the defaults
    for name, fallback in defaults:
        setattr(args, name, getattr(args, name, fallback))
@register_model_architecture("fconv_lm", "fconv_lm_dauphin_wikitext103")
def fconv_lm_dauphin_wikitext103(args):
    """Apply the ``fconv_lm_dauphin_wikitext103`` defaults, then the base LM defaults."""
    layers = "".join(
        (
            "[(850, 6)] * 3",
            " + [(850, 1)] * 1",
            " + [(850, 5)] * 4",
            " + [(850, 1)] * 1",
            " + [(850, 4)] * 3",
            " + [(1024, 4)] * 1",
            " + [(2048, 4)] * 1",
        )
    )
    preset = (
        ("decoder_embed_dim", 280),
        ("decoder_layers", layers),
        ("decoder_attention", "False"),
        ("adaptive_softmax_cutoff", "10000,20000,200000"),
    )
    # values already present on args win over the preset
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_lm_architecture(args)
@register_model_architecture("fconv_lm", "fconv_lm_dauphin_gbw")
def fconv_lm_dauphin_gbw(args):
    """Apply the ``fconv_lm_dauphin_gbw`` defaults, then the base LM defaults."""
    layers = "".join(
        (
            "[(512, 5)]",
            " + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3",
            " + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3",
            " + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6",
            " + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]",
        )
    )
    preset = (
        ("decoder_embed_dim", 128),
        ("decoder_layers", layers),
        ("decoder_attention", "False"),
        ("adaptive_softmax_cutoff", "10000,50000,200000"),
    )
    # values already present on args win over the preset
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_lm_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import (
CompositeEncoder,
FairseqDecoder,
FairseqEncoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
DownsampledMultiHeadAttention,
FairseqDropout,
GradMultiply,
LayerNorm,
LearnedPositionalEmbedding,
LinearizedConvolution,
)
logger = logging.getLogger(__name__)
@register_model("fconv_self_att")
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
return {
"conv.stories.pretrained": {
"path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
"checkpoint_file": "pretrained_checkpoint.pt",
"tokenizer": "nltk",
},
"conv.stories": {
"path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
"checkpoint_file": "fusion_checkpoint.pt",
"tokenizer": "nltk",
"pretrained": "True",
"pretrained_checkpoint": "./pretrained_checkpoint.pt",
},
# Test set containing dictionaries
"data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2",
}
def __init__(self, encoder, decoder, pretrained_encoder=None):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(
layer is not None for layer in decoder.attention
)
self.pretrained_encoder = pretrained_encoder
if self.pretrained_encoder is None:
encoders = {"encoder": encoder}
else:
encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder}
# for fusion model, CompositeEncoder contains both pretrained and training encoders
# these are forwarded and then combined in the decoder
self.encoder = CompositeEncoder(encoders)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
help='encoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--self-attention', type=str, metavar='EXPR',
help='decoder self-attention layers, ex: [True] + [False]*5')
parser.add_argument('--multihead-attention-nheads', type=int,
help='Number of heads to use in attention')
parser.add_argument('--multihead-self-attention-nheads', type=int,
help='Number of heads to use in self-attention')
parser.add_argument('--encoder-attention', type=str, metavar='EXPR',
help='encoder attention [True, ...]')
parser.add_argument('--encoder-attention-nheads', type=int,
help='Number of heads to use in encoder attention')
parser.add_argument('--project-input', type=str, metavar='EXPR',
help='Use projections in self-attention [True, ...]')
parser.add_argument('--gated-attention', type=str, metavar='EXPR',
help='Use GLU layers in self-attention projections [True, ...]')
parser.add_argument('--downsample', type=str, metavar='EXPR',
help='Use downsampling in self-attention [True, ...]')
parser.add_argument('--pretrained-checkpoint', metavar='DIR',
help='path to load checkpoint from pretrained model')
parser.add_argument('--pretrained', type=str, metavar='EXPR',
help='use pretrained model when training [True, ...]')
# fmt: on
@classmethod
def build_model(cls, args, task):
    """Build a new model instance.

    When ``args.pretrained`` evaluates to a truthy value, a checkpoint is
    loaded from ``args.pretrained_checkpoint`` (falling back to the same path
    joined onto ``args.data`` if the first does not exist), its first two child
    modules are taken as encoder and decoder, and all of their parameters are
    frozen. The frozen decoder is handed to the new ``FConvDecoder`` and the
    frozen encoder to the returned ``FConvModelSelfAtt``.
    """
    # NOTE(review): the EXPR-style options are evaluated with eval(); this is
    # only safe as long as they come from a trusted command line / config.
    use_pretrained = eval(args.pretrained)
    frozen_encoder = None
    frozen_decoder = None

    if use_pretrained:
        logger.info("loading pretrained model")
        if not os.path.exists(args.pretrained_checkpoint):
            # retry relative to the data directory
            fallback_path = os.path.join(args.data, args.pretrained_checkpoint)
            if os.path.exists(fallback_path):
                args.pretrained_checkpoint = fallback_path

        loaded = checkpoint_utils.load_model_ensemble(
            filenames=[args.pretrained_checkpoint],
            task=task,
        )
        pretrained_model = loaded[0][0]
        frozen_decoder = list(pretrained_model.children())[1]
        frozen_encoder = list(pretrained_model.children())[0]

        # freeze pretrained model (decoder first, then encoder)
        for frozen_module in (frozen_decoder, frozen_encoder):
            for weight in frozen_module.parameters():
                weight.requires_grad = False

    source_vocab = task.source_dictionary
    encoder = FConvEncoder(
        source_vocab,
        embed_dim=args.encoder_embed_dim,
        convolutions=eval(args.encoder_layers),
        dropout=args.dropout,
        max_positions=args.max_source_positions,
        attention=eval(args.encoder_attention),
        attention_nheads=args.encoder_attention_nheads,
    )

    target_vocab = task.target_dictionary
    decoder = FConvDecoder(
        target_vocab,
        embed_dim=args.decoder_embed_dim,
        convolutions=eval(args.decoder_layers),
        out_embed_dim=args.decoder_out_embed_dim,
        attention=eval(args.decoder_attention),
        dropout=args.dropout,
        max_positions=args.max_target_positions,
        selfattention=eval(args.self_attention),
        attention_nheads=args.multihead_attention_nheads,
        selfattention_nheads=args.multihead_self_attention_nheads,
        project_input=eval(args.project_input),
        gated_attention=eval(args.gated_attention),
        downsample=eval(args.downsample),
        pretrained=use_pretrained,
        trained_decoder=frozen_decoder,
    )

    return FConvModelSelfAtt(encoder, decoder, frozen_encoder)
@property
def pretrained(self):
    """True when ``self.pretrained_encoder`` is set (i.e. is not ``None``)."""
    attached_encoder = self.pretrained_encoder
    return attached_encoder is not None
class FConvEncoder(FairseqEncoder):
    """Convolutional encoder with optional per-layer self-attention.

    Args:
        dictionary: vocabulary; ``len(dictionary)`` sizes the token embedding
            and ``dictionary.pad()`` supplies the padding index.
        embed_dim: dimension of the token and positional embeddings.
        max_positions: size handed to the positional embedding.
        convolutions: sequence of ``(out_channels, kernel_size)`` pairs, one
            per convolutional layer.
        dropout: dropout probability; also passed to the ``Linear``/``ConvTBC``
            factories, which use it when initialising weights.
        attention: a single bool (applied to every layer) or a per-layer
            sequence selecting which layers get a ``SelfAttention`` module.
        attention_nheads: number of heads for those ``SelfAttention`` modules.
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        dropout=0.1,
        attention=False,
        attention_nheads=1,
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # Not computed in this class: forward() divides by this value, so it
        # has to be assigned from outside before the encoder is run.
        # NOTE(review): presumably set by the owning model -- confirm.
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
        )

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)

        in_channels = convolutions[0][0]
        # projects the embedding to the channel count of the first conv layer
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        # created but never appended to or read inside this class
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            # residual projection is only needed when the channel count changes
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            # 2x output channels: the GLU in forward() halves them again
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
            )

            self.attention.append(
                SelfAttention(out_channels, embed_dim, attention_nheads)
                if attention[i]
                else None
            )
            in_channels = out_channels

        # projects the last layer's channels back to the embedding size
        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens, src_lengths):
        """Encode ``src_tokens``.

        ``src_lengths`` is accepted but not used in this method.

        Returns:
            dict with ``"encoder_out"``: a tuple ``(x, y)`` where ``x`` is the
            projected convolution output and ``y`` is ``x`` combined with the
            input embedding, and ``"encoder_padding_mask"``: the padding mask
            (``None`` when no position equals the padding index).
        """
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = self.dropout_module(x)
        # kept so it can be added back to the output at the end (see ``y``)
        input_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv, attention in zip(
            self.projections, self.convolutions, self.attention
        ):
            residual = x if proj is None else proj(x)

            # zero out padded positions before they enter the convolution
            if encoder_padding_mask is not None:
                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

            x = self.dropout_module(x)

            # pad the first (time) dimension so the conv output keeps length T
            padding_l = (conv.kernel_size[0] - 1) // 2
            padding_r = conv.kernel_size[0] // 2
            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
            x = conv(x)
            x = F.glu(x, dim=2)
            if attention is not None:
                x = attention(x)
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

        # scale gradients (this only affects backward, not forward)
        # requires num_attention_layers to have been set (it is None by default)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)

        return {
            "encoder_out": (x, y),
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder ``encoder_out`` along dimension 0 according to ``new_order``.

        The dict is updated in place and also returned. A ``"pretrained"``
        entry, when present, has its ``"encoder_out"`` tuple reordered too.
        """
        encoder_out["encoder_out"] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out["encoder_out"]
        )

        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)

        if "pretrained" in encoder_out:
            encoder_out["pretrained"]["encoder_out"] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out["pretrained"]["encoder_out"]
            )

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions
@with_incremental_state
class FConvDecoder(FairseqDecoder):
    """Convolutional decoder with encoder attention, optional self-attention
    and optional fusion with a pretrained decoder.

    Args:
        dictionary: target vocabulary; sizes the embedding and output layer.
        embed_dim: dimension of the token and positional embeddings.
        out_embed_dim: dimension of the hidden state before the output layer.
        max_positions: size handed to the positional embedding.
        convolutions: sequence of ``(out_channels, kernel_size)`` pairs.
        attention: bool or per-layer list enabling encoder attention.
        dropout: dropout probability (also used by the layer factories).
        selfattention: bool or per-layer sequence enabling ``SelfAttention``.
        attention_nheads: heads for the encoder attention.
        selfattention_nheads: heads for the self-attention.
        project_input, gated_attention, downsample: forwarded to the
            attention modules (see the inline notes for which ones).
        pretrained: when true, fusion layers are built and ``trained_decoder``
            is run inside ``forward``.
        trained_decoder: the pretrained decoder; must expose ``fc2`` when
            ``pretrained`` is true.
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=((512, 3),) * 8,
        attention=True,
        dropout=0.1,
        selfattention=False,
        attention_nheads=1,
        selfattention_nheads=1,
        project_input=False,
        gated_attention=False,
        downsample=False,
        pretrained=False,
        trained_decoder=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([2]))
        self.pretrained = pretrained
        self.pretrained_decoder = trained_decoder
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # toggled by make_generation_fast_()
        self.need_attn = True
        in_channels = convolutions[0][0]

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)
        selfattention = expand_bool_array(selfattention)

        # NOTE(review): only `attention` is validated here; `selfattention`
        # is indexed per layer below without an equivalent length check.
        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError(
                "Attention is expected to be a list of booleans of "
                "length equal to the number of layers."
            )

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)

        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            padding_idx,
        )

        # projects the embedding to the channel count of the first conv layer
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.selfattention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            # residual projection is only needed when the channel count changes
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            # 2x output channels: the GLU in forward() halves them again
            self.convolutions.append(
                LinearizedConv1d(
                    in_channels,
                    out_channels * 2,
                    kernel_size,
                    padding=(kernel_size - 1),
                    dropout=dropout,
                )
            )

            # encoder attention: gating and downsampling are always disabled
            # here, independent of the constructor arguments
            self.attention.append(
                DownsampledMultiHeadAttention(
                    out_channels,
                    embed_dim,
                    attention_nheads,
                    project_input=project_input,
                    gated=False,
                    downsample=False,
                )
                if attention[i]
                else None
            )

            # maps conv channels to embed_dim so the target embedding can be
            # added to form the attention query (see forward)
            self.attproj.append(
                Linear(out_channels, embed_dim, dropout=dropout)
                if attention[i]
                else None
            )
            self.selfattention.append(
                SelfAttention(
                    out_channels,
                    embed_dim,
                    selfattention_nheads,
                    project_input=project_input,
                    gated=gated_attention,
                    downsample=downsample,
                )
                if selfattention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

        # model fusion
        if self.pretrained:
            # independent gates are learned from the concatenated input
            self.gate1 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            self.gate2 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            # pretrained and trained models are joined
            self.joining = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim),
                LayerNorm(out_embed_dim),
            )
            # pretrained model contains an output layer that is nhid -> vocab size
            # but the models are combined in their hidden state
            # the hook stores the output of the pretrained model forward
            self.pretrained_outputs = {}

            def save_output():
                def hook(a, b, output):
                    self.pretrained_outputs["out"] = output

                return hook

            self.pretrained_decoder.fc2.register_forward_hook(save_output())

    def forward(self, prev_output_tokens, encoder_out):
        """Decode ``prev_output_tokens`` given ``encoder_out``.

        ``encoder_out`` must contain ``["encoder"]["encoder_out"]`` (a pair of
        tensors) and, when this decoder was built with ``pretrained=True``, a
        ``"pretrained"`` entry that is passed to the pretrained decoder.

        Returns:
            ``(output, avg_attn_scores)``. ``avg_attn_scores`` is the sum of
            the per-layer attention scores, and stays ``None`` during training
            or when ``need_attn`` is false.
        """
        trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None
        encoder_out = encoder_out["encoder"]["encoder_out"]

        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(prev_output_tokens)

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens) + positions
        x = self.dropout_module(x)
        target_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        avg_attn_scores = None
        for proj, conv, attention, selfattention, attproj in zip(
            self.projections,
            self.convolutions,
            self.attention,
            self.selfattention,
            self.attproj,
        ):
            residual = x if proj is None else proj(x)

            x = self.dropout_module(x)
            x = conv(x)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                r = x
                # query = projected conv state + target embedding;
                # encoder_a / encoder_b are the two encoder outputs
                x, attn_scores = attention(
                    attproj(x) + target_embedding, encoder_a, encoder_b
                )
                x = x + r
                if not self.training and self.need_attn:
                    # accumulated in place; not divided by the layer count
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        avg_attn_scores.add_(attn_scores)

            if selfattention is not None:
                x = selfattention(x)

            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = self.fc2(x)
        x = self.dropout_module(x)
        # without fusion fc3 is applied directly; with fusion it is applied to
        # the joined hidden state below
        if not self.pretrained:
            x = self.fc3(x)

        # fusion gating
        if self.pretrained:
            # the return value is not used: the call is made so that the
            # forward hook on pretrained_decoder.fc2 fills pretrained_outputs
            trained_x, _ = self.pretrained_decoder.forward(
                prev_output_tokens, trained_encoder_out
            )
            y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
            gate1 = self.gate1(y)
            gate2 = self.gate2(y)
            gated_x1 = gate1 * x
            gated_x2 = gate2 * self.pretrained_outputs["out"]
            fusion = torch.cat([gated_x1, gated_x2], dim=-1)
            fusion = self.joining(fusion)
            fusion_output = self.fc3(fusion)
            return fusion_output, avg_attn_scores
        else:
            return x, avg_attn_scores

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Set ``need_attn``; other keyword arguments are ignored."""
        self.need_attn = need_attn

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
        encoder_b = encoder_b.transpose(0, 1).contiguous()
        result = (encoder_a, encoder_b)
        return result
class SelfAttention(nn.Module):
    """Self-attention block: q/k/v projections, multi-head attention over the
    input itself with future timesteps masked, then a residual connection
    followed by layer normalisation.
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        project_input=False,
        gated=False,
        downsample=False,
    ):
        super().__init__()
        # Sub-modules are created in a fixed order: attention, q, k, v, norm.
        attention_settings = {
            "dropout": 0,
            "bias": True,
            "project_input": project_input,
            "gated": gated,
            "downsample": downsample,
        }
        self.attention = DownsampledMultiHeadAttention(
            out_channels, embed_dim, num_heads, **attention_settings
        )
        self.in_proj_q = Linear(out_channels, embed_dim)
        self.in_proj_k = Linear(out_channels, embed_dim)
        self.in_proj_v = Linear(out_channels, embed_dim)
        self.ln = LayerNorm(out_channels)

    def forward(self, x):
        """Return ``LayerNorm(attention(x) + x)``."""
        q = self.in_proj_q(x)
        k = self.in_proj_k(x)
        v = self.in_proj_v(x)
        attended, _ = self.attention(
            q, k, v, mask_future_timesteps=True, use_scalar_bias=True
        )
        return self.ln(attended + x)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` and re-draw its weights from a normal
    distribution with mean 0 and standard deviation 0.1.
    """
    embedding = nn.Embedding(
        num_embeddings,
        embedding_dim,
        padding_idx=padding_idx,
    )
    weights = embedding.weight.data
    weights.normal_(mean=0, std=0.1)
    return embedding
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
    """Create a ``LearnedPositionalEmbedding`` and re-draw its weights from a
    normal distribution with mean 0 and standard deviation 0.1.
    """
    positional = LearnedPositionalEmbedding(
        num_embeddings,
        embedding_dim,
        padding_idx,
    )
    weights = positional.weight.data
    weights.normal_(mean=0, std=0.1)
    return positional
def Linear(in_features, out_features, dropout=0.0):
    """Create an ``nn.Linear`` (input: N x T x C) with custom initialisation.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``sqrt((1 - dropout) / in_features)``; the bias is zeroed.
    """
    layer = nn.Linear(in_features, out_features)
    init_std = math.sqrt((1 - dropout) / in_features)
    layer.weight.data.normal_(mean=0, std=init_std)
    layer.bias.data.zero_()
    return layer
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Create a ``LinearizedConvolution`` with custom initialisation.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``sqrt(4 * (1 - dropout) / (kernel_width * in_channels))``; the
    bias is zeroed. Extra keyword arguments go to the layer constructor.
    """
    conv = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    kernel_width = conv.kernel_size[0]
    init_std = math.sqrt((4 * (1.0 - dropout)) / (kernel_width * in_channels))
    conv.weight.data.normal_(mean=0, std=init_std)
    conv.bias.data.zero_()
    return conv
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Create a ``fairseq.modules.ConvTBC`` layer with custom initialisation.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``sqrt(4 * (1 - dropout) / (kernel_width * in_channels))``; the
    bias is zeroed. Extra keyword arguments go to the layer constructor.
    """
    # imported locally under an alias so it does not clash with this factory
    from fairseq.modules import ConvTBC as _ConvTBCLayer

    conv = _ConvTBCLayer(in_channels, out_channels, kernel_size, **kwargs)
    kernel_width = conv.kernel_size[0]
    init_std = math.sqrt((4 * (1.0 - dropout)) / (kernel_width * in_channels))
    conv.weight.data.normal_(mean=0, std=init_std)
    conv.bias.data.zero_()
    return conv
@register_model_architecture("fconv_self_att", "fconv_self_att")
def base_architecture(args):
    """Fill in the base ``fconv_self_att`` defaults on ``args``.

    Each option keeps its current value when ``args`` already has it and is
    otherwise set to the default listed below. Options are assigned in table
    order.
    """
    base_defaults = (
        ("dropout", 0.1),
        ("encoder_embed_dim", 512),
        ("encoder_layers", "[(512, 3)] * 3"),
        ("decoder_embed_dim", 512),
        ("decoder_layers", "[(512, 3)] * 8"),
        ("decoder_out_embed_dim", 256),
        ("decoder_attention", "True"),
        ("self_attention", "False"),
        ("encoder_attention", "False"),
        ("multihead_attention_nheads", 1),
        ("multihead_self_attention_nheads", 1),
        ("encoder_attention_nheads", 1),
        ("project_input", "False"),
        ("gated_attention", "False"),
        ("downsample", "False"),
        ("pretrained_checkpoint", ""),
        ("pretrained", "False"),
    )
    for option, fallback in base_defaults:
        setattr(args, option, getattr(args, option, fallback))
@register_model_architecture("fconv_self_att", "fconv_self_att_wp")
def fconv_self_att_wp(args):
    """Fill in the ``fconv_self_att_wp`` defaults on ``args``.

    Options already present on ``args`` are kept; the rest get the values
    below, after which ``base_architecture`` supplies any remaining defaults.
    """
    wp_defaults = (
        ("encoder_embed_dim", 256),
        ("encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1"),
        ("decoder_embed_dim", 256),
        ("decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1"),
        ("decoder_out_embed_dim", 256),
        ("self_attention", "True"),
        ("multihead_self_attention_nheads", 4),
        ("project_input", "True"),
        ("gated_attention", "True"),
        ("downsample", "True"),
    )
    for option, fallback in wp_defaults:
        setattr(args, option, getattr(args, option, fallback))
    base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .hubert import * # noqa
from .hubert_asr import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from omegaconf import II
from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import (
EXTRACTOR_MODE_CHOICES,
MASKING_DISTRIBUTION_CHOICES,
LAYER_TYPE_CHOICES,
ConvFeatureExtractionModel,
TransformerEncoder,
)
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (
HubertPretrainingConfig,
HubertPretrainingTask,
)
logger = logging.getLogger(__name__)
@dataclass
class HubertConfig(FairseqDataclass):
    """Configuration for ``HubertModel``.

    Each field carries its own ``help`` text in ``metadata``. Fields are
    grouped below into: feature extractor / transformer shape, dropouts,
    output projection, time masking, channel masking, positional embeddings,
    loss computation, and encoder-specific extras.
    """

    # interpolated from the task config ("task.label_rate")
    label_rate: float = II("task.label_rate")

    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )
    layer_type: LAYER_TYPE_CHOICES = field(
        default="transformer", metadata={"help": "layer type in encoder"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    # NOTE(review): help text below has a typo ("tarnsformer")
    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a tarnsformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    # output projection
    # HubertModel falls back to encoder_embed_dim when this is not > 0
    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim is <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    # this string is passed to eval() in HubertModel.__init__
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking
    mask_length: int = field(default=10, metadata={"help": "mask length"})
    mask_prob: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    # NOTE(review): help text below ends in "compute_mask_indicesh" (typo)
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indicesh"
        },
    )
    # NOTE(review): HubertModel passes this as `no_overlap=` to
    # compute_mask_indices, so True appears to forbid overlap; the help text
    # reads the other way round -- confirm.
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    # NOTE(review): same "compute_mask_indicesh" typo as mask_other
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indicesh"
        },
    )
    # NOTE(review): passed as `no_overlap=` for channel masks; see the note on
    # no_mask_overlap about the wording of the help text.
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )

    checkpoint_activations: bool = field(
        default=False,
        metadata={"help": "recompute activations and save memory for extra compute"},
    )

    # FP16 optimization
    required_seq_len_multiple: int = field(
        default=2,
        metadata={
            "help": "pad the input to encoder such that the sequence length is divisible by multiple"
        },
    )

    # Conformer
    depthwise_conv_kernel_size: int = field(
        default=31,
        metadata={
            "help": "depthwise-conv-kernel-size for convolution in conformer layer"
        },
    )
    attn_type: str = field(
        default="",
        metadata={"help": "if espnet use ESPNET MHA"},
    )
    pos_enc_type: str = field(
        default="abs",
        metadata={"help": "Positional encoding type to use in conformer"},
    )
    fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
@register_model("hubert", dataclass=HubertConfig)
class HubertModel(BaseFairseqModel):
    """HuBERT model: convolutional feature extractor, optional time/channel
    masking, transformer encoder, and a projection head that scores encoder
    outputs against learned label embeddings by cosine similarity.
    """

    def __init__(
        self,
        cfg: HubertConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
    ) -> None:
        """Build all sub-modules from ``cfg``.

        ``task_cfg.sample_rate`` is used for the feature/label rate ratio.
        If any entry of ``dictionaries`` is ``None`` the label embeddings are
        not created (the model is assumed to be used for fine-tuning).
        """
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")

        # NOTE(review): eval() on a config string -- only safe for trusted configs
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        # first entry of the last layer spec, i.e. its `dim`
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        # product of all strides = total downsampling of the extractor
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # used by forward_targets() to map a feature index to a label index
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        # only needed when extractor and encoder dimensions differ
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        # learned vector written over masked time steps in apply_mask()
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.untie_final_proj = cfg.untie_final_proj
        if self.untie_final_proj:
            # one final_dim-sized slice per dictionary; split with chunk() in forward
            self.final_proj = nn.Linear(
                cfg.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any([d is None for d in dictionaries]):
            logger.info("cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            # label embeddings of all dictionaries stacked along dim 0;
            # forward() splits them back per dictionary using num_classes
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask):
        """Build a new model instance."""

        model = HubertModel(cfg, task.cfg, task.dictionaries)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        """Apply time masking and channel masking to ``x``.

        ``x`` is modified in place and also returned. ``target_list`` is not
        used. Returns ``(x, mask_indices)`` where ``mask_indices`` is the time
        mask, or ``None`` when ``mask_prob`` is not positive.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            # masked time steps are replaced by the learned mask embedding
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # (B, C) -> (B, T, C): the same channels are masked at every step
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def compute_nce(self, x, pos, negs):
        """Return cosine-similarity logits of ``x`` against ``pos`` and ``negs``.

        The positive is placed at index 0 of the class dimension. Negatives
        that are identical to the positive get a logit of ``-inf``. Logits are
        divided by ``logit_temp``.
        """
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            # row 0 is the positive itself, so only rows 1: are masked out
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor on ``source``.

        With ``feature_grad_mult > 0`` gradients flow (scaled by that factor
        when it is not 1.0); otherwise the extractor runs under ``no_grad``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Trim ``features`` (along dim 2) so labels exist for every frame and
        pick, for each remaining frame, the label at index
        ``floor(frame_index * feat2tar_ratio)`` from every target tensor.
        """
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce ``padding_mask`` to one entry per feature frame.

        Trailing entries that do not fill a whole frame are dropped; a frame
        is marked as padding only if all entries grouped into it are padding.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based

        With ``features_only=True`` returns ``{"x", "padding_mask",
        "features"}``. Otherwise returns the masked/unmasked logit lists, the
        padding mask and the feature penalty.

        NOTE(review): the non-``features_only`` path applies ``~`` to
        ``padding_mask`` and ``mask_indices`` and iterates ``target_list``;
        it appears to require all three to be non-``None`` -- confirm.
        """
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        # mean squared feature value, reported as an extra loss
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        # unmasked_features is not read again after this line
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            # apply_mask writes into `features` in place; x aliases it
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            # every label embedding is used as a candidate for every position
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            if self.target_glu:
                y = self.target_glu(y)
                negs = self.target_glu(negs)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            # non-padded positions that were masked
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            if self.untie_final_proj:
                proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
            else:
                proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
            logit_m_list = [
                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
                for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
            ]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            # non-padded positions that were not masked
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            if self.untie_final_proj:
                proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
            else:
                proj_x_u_list = [proj_x_u for _ in range(len(target_list))]

            logit_u_list = [
                compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
                for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
            ]
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }
        return result

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward`` with ``features_only=True``.

        Returns ``(feature, padding_mask)`` where ``feature`` is the
        pre-encoder ``"features"`` entry when ``ret_conv`` is true and the
        encoder output ``"x"`` otherwise.
        """
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        """Return the masked (or unmasked) logits as float tensors, dropping
        ``None`` entries.
        """
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        """Return all-zero class targets, one per logit row (compute_nce puts
        the positive at class index 0).
        """
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        """Return ``(losses, names)`` for extra losses found in ``net_output``;
        currently only ``"features_pen"``.
        """
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        """Drop the modules only used for the pretraining loss."""
        self.target_glu = None
        self.final_proj = None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.nn as nn
from omegaconf import II, MISSING
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
from fairseq.tasks import FairseqTask
@dataclass
class HubertAsrConfig(FairseqDataclass):
    """Options shared by the HuBERT fine-tuning (ASR) models in this file."""

    # --- checkpoint -------------------------------------------------------
    w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"})
    no_pretrained_weights: bool = field(
        default=False, metadata={"help": "if true, does not load pretrained weights"}
    )

    # --- dropout ----------------------------------------------------------
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout after transformer and before final projection"},
    )
    dropout: float = field(
        default=0.0, metadata={"help": "dropout probability inside hubert model"}
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside hubert model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside hubert model"
        },
    )

    # --- time masking -----------------------------------------------------
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(
        default=10, metadata={"help": "repeat the mask indices multiple times"}
    )
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask (normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose masks"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    # NOTE(review): the help text reads inverted relative to the field name
    # ("no_..."); confirm the intended polarity before relying on it.
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )

    # --- channel masking --------------------------------------------------
    mask_channel_length: int = field(
        default=10, metadata={"help": "length of the mask for features (channels)"}
    )
    mask_channel_prob: float = field(
        default=0.0, metadata={"help": "probability of replacing a feature with 0"}
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    # NOTE(review): same polarity question as ``no_mask_overlap``.
    no_mask_channel_overlap: bool = field(
        default=False, metadata={"help": "whether to allow channel masks to overlap"}
    )

    # --- fine-tuning schedule / encoder tweaks ----------------------------
    freeze_finetune_updates: int = field(
        default=0, metadata={"help": "dont finetune hubert for this many updates"}
    )
    feature_grad_mult: float = field(
        default=0.0, metadata={"help": "reset feature grad mult in hubert to this"}
    )
    layerdrop: float = field(
        default=0.0, metadata={"help": "probability of dropping a layer in hubert"}
    )

    # Interpolated from the task section of the config.
    normalize: bool = II("task.normalize")
    data: str = II("task.data")

    # this holds the loaded hubert args
    w2v_args: Any = None
@dataclass
class HubertCtcConfig(HubertAsrConfig):
    """Configuration for ``hubert_ctc``; adds no fields to ``HubertAsrConfig``."""
@register_model("hubert_ctc", dataclass=HubertCtcConfig)
class HubertCtc(BaseFairseqModel):
    """Model registered as ``hubert_ctc``: a thin wrapper around a HuBERT encoder.

    All computation is delegated to ``w2v_encoder``; this class only adapts
    the encoder output dictionary to the probability/logit accessors.
    """

    def __init__(self, cfg: HubertCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        """Run the base-class upgrade and hand the state dict back."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = HubertEncoder(cfg, task)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        """Return the encoder logits with padded frames forced to class 0.

        For every padded frame the logit of index 0 is set to ``0`` and all
        other logits to ``-inf``. The tensor stored in ``net_output`` is
        modified in place and also returned.

        Args:
            net_output: encoder output with ``"encoder_out"`` (T x B x C) and
                ``"encoder_padding_mask"`` (B x T, or ``None``).
        """
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            # B x T -> T x B so the mask lines up with the leading logits dims.
            padding = padding.T
            # Indexing with a boolean mask yields a copy, so a chained
            # ``logits[padding][..., 0] = 0`` never reaches ``logits``. Build
            # the replacement row once and assign it through the mask instead.
            pad_row = logits.new_full((logits.size(-1),), float("-inf"))
            pad_row[0] = 0
            logits[padding] = pad_row
        return logits

    def forward(self, **kwargs):
        """Forward all keyword arguments to the wrapped encoder."""
        x = self.w2v_encoder(**kwargs)
        return x
@dataclass
class HubertSeq2SeqConfig(HubertAsrConfig):
    """``HubertAsrConfig`` extended with decoder options."""

    # --- decoder size -----------------------------------------------------
    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )

    # --- decoder structure ------------------------------------------------
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layernorm before each decoder block"},
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )

    # --- decoder dropout --------------------------------------------------
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )

    # --- misc -------------------------------------------------------------
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False,
        metadata={"help": "share decoder input and output embeddings"},
    )
class HubertEncoder(FairseqEncoder):
    """Encoder that wraps a HuBERT model rebuilt from its pre-training config.

    The constructor loads (or reuses) the pre-training arguments, rebuilds the
    pre-training task and model, optionally restores the checkpoint weights,
    strips the pre-training-only modules and adds an optional output
    projection.
    """

    def __init__(self, cfg: HubertAsrConfig, task):
        self.apply_mask = cfg.apply_mask

        # Fine-tuning values passed to the checkpoint loader as overrides.
        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            # No cached args: read them (and the weights) from the checkpoint.
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                # Checkpoint stores an argparse namespace under "args" instead.
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            # Args already cached on the config: no checkpoint is read here.
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        pretrain_task = tasks.setup_task(w2v_args.task)
        if state is not None and "task_state" in state:
            # This will load the stored "dictionaries" object
            pretrain_task.load_state_dict(state["task_state"])
        else:
            pretrain_task.load_state_dict(task.state_dict())

        model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        if state is not None and not cfg.no_pretrained_weights:
            # set strict=False because we omit some modules
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(pretrain_task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # Output projection: to the target vocabulary when the task has one,
        # otherwise to the decoder width when that differs from the encoder's.
        if task.target_dictionary is not None:
            self.proj = Linear(d, len(task.target_dictionary))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        """Encode ``source`` with the wrapped HuBERT model.

        Args:
            source: input forwarded to ``w2v_model.extract_features``.
            padding_mask: padding mask forwarded alongside ``source``.
            tbc: when true, swap the first two dimensions of the features
                (B x T x C -> T x B x C).
            **kwargs: accepted and ignored.

        Returns:
            Dict with ``"encoder_out"`` and the mask returned by the wrapped
            model under both ``"encoder_padding_mask"`` and ``"padding_mask"``.
        """
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # The wrapped model runs without gradients until enough updates passed.
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        # Compare against None explicitly rather than relying on the
        # truthiness of an nn.Module.
        if self.proj is not None:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension of ``encoder_out`` in place.

        ``"encoder_out"`` is indexed along dim 1 and the padding masks along
        dim 0. ``"padding_mask"`` is reordered too, so it does not keep the
        old batch order while ``"encoder_padding_mask"`` moves.
        """
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        # ``.get`` keeps dictionaries without this key working as before.
        if encoder_out.get("padding_mask") is not None:
            encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
                0, new_order
            )
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        """Return ``state_dict`` unchanged."""
        return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with scaled-normal weights and a zero pad row.

    Args:
        num_embeddings: number of rows in the table.
        embedding_dim: width of each embedding; weights are drawn from
            ``N(0, embedding_dim ** -0.5)``.
        padding_idx: row to zero out, or ``None`` for no padding row.

    Returns:
        The initialised ``nn.Embedding`` module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    # ``m.weight[None]`` is a view of the *whole* table, so zeroing it without
    # this guard would wipe every embedding when no padding index is given.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: input width.
        out_features: output width.
        bias: whether the layer has a bias term.
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    # nn.Linear leaves ``bias`` as None when it was built without one.
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# Import every public module or sub-package found next to this file so that
# importing the package runs each of them once.
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    # Skip private ("_") and hidden (".") entries.
    if file.startswith(("_", ".")):
        continue
    if file.endswith(".py"):
        model_name = file[: file.find(".py")]
    elif os.path.isdir(path):
        model_name = file
    else:
        # Neither a Python source file nor a directory.
        continue
    module = importlib.import_module("fairseq.models.huggingface." + model_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
from typing import Dict, List, Optional
import torch
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
logger = logging.getLogger(__name__)
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("hf_gpt2")
class HuggingFaceGPT2LanguageModel(FairseqLanguageModel):
    """Language model registered as ``hf_gpt2``, backed by ``HuggingFaceGPT2Decoder``."""

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        option_specs = (
            ("--embed-dim", int, "N", "embedding dimension"),
            ("--num-attention-heads", int, "N", "num attention heads"),
            ("--num-layers", int, "N", "num layers"),
            (
                "--dropout",
                float,
                "D",
                "dropout probability for all fully connected layers "
                "in the embeddings, encoder, and pooler",
            ),
            (
                "--attention-dropout",
                float,
                "D",
                "dropout probability for attention weights",
            ),
        )
        for flag, value_type, metavar, help_text in option_specs:
            parser.add_argument(flag, type=value_type, metavar=metavar, help=help_text)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # Fill in any architecture option the caller left unset.
        default_architecture(args)
        decoder = HuggingFaceGPT2Decoder(args, task)
        return cls(decoder)
class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):
    """Incremental decoder that wraps a Hugging Face ``GPT2LMHeadModel``.

    Position id 0 is reserved for padding: the position table gets one extra
    row (``n_positions = max_target_positions + 1``) and row 0 is zeroed.
    """

    def __init__(self, args, task):
        try:
            from transformers import GPT2Config, GPT2LMHeadModel
        except ImportError:
            raise ImportError(
                "\n\nPlease install huggingface/transformers with:"
                "\n\n pip install transformers"
            )

        super().__init__(task.target_dictionary)

        config = GPT2Config(
            vocab_size=len(task.target_dictionary),
            n_positions=args.max_target_positions + 1,
            n_ctx=args.max_target_positions,
            n_embd=args.embed_dim,
            n_layer=args.num_layers,
            n_head=args.num_attention_heads,
            resid_pdrop=args.dropout,
            embd_pdrop=args.dropout,
            attn_pdrop=args.attention_dropout,
            layer_norm_epsilon=1e-6,
        )
        self.model = GPT2LMHeadModel(config)

        # set zero embedding for padding symbol
        self.pad_idx = task.target_dictionary.pad()
        self.model.transformer.wte.weight.data[self.pad_idx].zero_()
        self.model.transformer.wpe.weight.data[0].zero_()

    def forward(
        self,
        prev_output_tokens,
        src_lengths=None,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        """Return a 1-tuple holding the LM logits for ``prev_output_tokens``.

        ``src_lengths`` and ``encoder_out`` are accepted but not used.
        """
        features = self.extract_features(prev_output_tokens, incremental_state)
        lm_logits = self.model.lm_head(features)
        return (lm_logits,)

    def extract_features(
        self,
        prev_output_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
    ):
        """Run the GPT-2 transformer and return its last hidden states.

        Args:
            prev_output_tokens: token ids; the second dimension is the
                sequence position.
            incremental_state: optional cache. When it is truthy the entry
                stored under ``"past"`` is passed to the transformer and
                refreshed afterwards.
        """
        # NOTE(review): an empty dict is falsy, so a freshly created cache
        # never gets populated here -- confirm this is intended.
        if incremental_state:
            # The cache must be passed explicitly, matching the
            # ``set_incremental_state`` call below; with only the key the
            # lookup raised TypeError for a missing argument.
            past = self.get_incremental_state(incremental_state, "past")
        else:
            past = None

        # don't attend to padding symbols
        attention_mask = prev_output_tokens.ne(self.pad_idx).int()

        # set position ids to exclude padding symbols
        # (non-pad tokens get their 1-based column index, pad tokens get 0)
        position_ids = attention_mask * (
            torch.arange(1, 1 + prev_output_tokens.size(1))
            .to(prev_output_tokens)
            .repeat(prev_output_tokens.size(0), 1)
        )

        # NOTE(review): the ``past`` keyword is presumably tied to the
        # transformers release this was written against -- confirm.
        outputs = self.model.transformer(
            input_ids=prev_output_tokens,
            past=past,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        last_hidden_states = outputs[0]

        if incremental_state:
            self.set_incremental_state(incremental_state, "past", outputs[1])

        return last_hidden_states

    def max_positions(self):
        """Return the usable number of positions (row 0 is reserved for padding)."""
        return self.model.config.n_positions - 1
@register_model_architecture("hf_gpt2", "hf_gpt2")
def default_architecture(args):
    """Fill in the base ``hf_gpt2`` hyper-parameters the caller left unset."""
    # ``max_target_positions`` may exist but be None, so test the value
    # rather than the attribute's presence.
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = getattr(
            args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
        )
    base_defaults = (
        ("embed_dim", 768),
        ("num_attention_heads", 12),
        ("num_layers", 12),
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
    )
    for attr_name, fallback in base_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
@register_model_architecture("hf_gpt2", "hf_gpt2_medium")
def hf_gpt2_medium(args):
    """Apply the ``hf_gpt2_medium`` sizes, then the shared defaults."""
    size_defaults = (
        ("embed_dim", 1024),
        ("num_attention_heads", 16),
        ("num_layers", 24),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    default_architecture(args)
@register_model_architecture("hf_gpt2", "hf_gpt2_large")
def hf_gpt2_large(args):
    """Apply the ``hf_gpt2_large`` sizes, then the shared defaults."""
    size_defaults = (
        ("embed_dim", 1280),
        ("num_attention_heads", 20),
        ("num_layers", 36),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    default_architecture(args)
@register_model_architecture("hf_gpt2", "hf_gpt2_xl")
def hf_gpt2_xl(args):
    """Apply the ``hf_gpt2_xl`` sizes, then the shared defaults."""
    size_defaults = (
        ("embed_dim", 1600),
        ("num_attention_heads", 25),
        ("num_layers", 48),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    default_architecture(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment