# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .hub_interface import * # noqa
from .model import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
logger = logging.getLogger(__name__)
class BARTHubInterface(nn.Module):
"""A simple PyTorch Hub interface to BART.
Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
"""
def __init__(self, args, task, model):
super().__init__()
self.args = args
self.task = task
self.model = model
self.bpe = encoders.build_bpe(args)
self.max_positions = min(
utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
)
)
# this is useful for determining the device
self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
def encode(
self, sentence: str, *addl_sentences, no_separator=True
) -> torch.LongTensor:
"""
BPE-encode a sentence (or multiple sentences).
Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
Every sentence ends with an end-of-sentence (`</s>`).
Example (single sentence): `<s> a b c </s>`
Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`
The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
requires leading spaces. For example::
>>> bart.encode('Hello world').tolist()
[0, 31414, 232, 2]
>>> bart.encode(' world').tolist()
[0, 232, 2]
>>> bart.encode('world').tolist()
[0, 8331, 2]
"""
tokens = self.bpe.encode(sentence)
if len(tokens.split(" ")) > self.max_positions - 2:
tokens = " ".join(tokens.split(" ")[: self.max_positions - 2])
bpe_sentence = "<s> " + tokens + " </s>"
for s in addl_sentences:
bpe_sentence += " </s>" if not no_separator else ""
bpe_sentence += " " + self.bpe.encode(s) + " </s>"
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
return tokens.long()
def decode(self, tokens: torch.LongTensor):
assert tokens.dim() == 1
tokens = tokens.cpu().numpy()
if tokens[0] == self.task.source_dictionary.bos():
tokens = tokens[1:] # remove <s>
eos_mask = tokens == self.task.source_dictionary.eos()
doc_mask = eos_mask[1:] & eos_mask[:-1]
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
sentences = [
self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
]
if len(sentences) == 1:
return sentences[0]
return sentences
def _build_sample(self, src_tokens: List[torch.LongTensor]):
# assert torch.is_tensor(src_tokens)
dataset = self.task.build_dataset_for_inference(
src_tokens,
[x.numel() for x in src_tokens],
)
sample = dataset.collater(dataset)
sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
return sample
    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
input = [self.encode(sentence) for sentence in sentences]
hypos = self.generate(input, beam, verbose, **kwargs)
return [self.decode(x["tokens"]) for x in hypos]
def generate(
self,
tokens: List[torch.LongTensor],
beam: int = 5,
verbose: bool = False,
**kwargs
    ) -> List[Dict[str, torch.Tensor]]:
sample = self._build_sample(tokens)
# build generator using current args as well as any kwargs
gen_args = copy.copy(self.args)
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator([self.model], gen_args)
translations = self.task.inference_step(
generator,
[self.model],
sample,
prefix_tokens=sample["net_input"]["src_tokens"]
.new_zeros((len(tokens), 1))
.fill_(self.task.source_dictionary.bos()),
)
        if verbose:
            # `tokens` is a list of tensors; decode each source sequence separately.
            for t in tokens:
                logger.info("S\t{}".format(self.task.source_dictionary.string(t)))
def getarg(name, default):
return getattr(gen_args, name, getattr(self.args, name, default))
# Process top predictions
hypos = [x[0] for x in translations]
hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))]
return hypos
def extract_features(
self, tokens: torch.LongTensor, return_all_hiddens: bool = False
) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.size(-1) > min(self.model.max_positions()):
raise ValueError(
"tokens exceeds maximum length: {} > {}".format(
tokens.size(-1), self.model.max_positions()
)
)
        tokens = tokens.to(device=self.device)
        # Build decoder inputs: move the final (EOS) token to the front and shift
        # the remaining tokens right by one, matching BART's decoder input format.
        prev_output_tokens = tokens.clone()
        prev_output_tokens[:, 0] = tokens.gather(
            1,
            (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
        ).squeeze()
        prev_output_tokens[:, 1:] = tokens[:, :-1]
features, extra = self.model(
src_tokens=tokens,
src_lengths=None,
prev_output_tokens=prev_output_tokens,
features_only=True,
return_all_hiddens=return_all_hiddens,
)
if return_all_hiddens:
# convert from T x B x C -> B x T x C
inner_states = extra["inner_states"]
return [inner_state.transpose(0, 1) for inner_state in inner_states]
else:
return features # just the last layer's features
def register_classification_head(
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
):
self.model.register_classification_head(
name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
)
def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
features = self.extract_features(tokens.to(device=self.device))
sentence_representation = features[
tokens.eq(self.task.source_dictionary.eos()), :
].view(features.size(0), -1, features.size(-1))[:, -1, :]
logits = self.model.classification_heads[head](sentence_representation)
if return_logits:
return logits
return F.log_softmax(logits, dim=-1)
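

# Usage sketch (illustrative): load a pretrained BART through torch.hub and
# round-trip a sentence through encode/decode, then generate with beam search.
# The 'pytorch/fairseq' hub repo and 'bart.large' model name follow the fairseq
# BART examples; downloading the checkpoint requires network access (or a
# previously cached copy), and `lenpen` is a standard fairseq generation option.
if __name__ == "__main__":
    bart = torch.hub.load("pytorch/fairseq", "bart.large")
    bart.eval()  # disable dropout for deterministic behavior
    tokens = bart.encode("Hello world")
    print(tokens.tolist())      # [0, 31414, 232, 2]
    print(bart.decode(tokens))  # 'Hello world'
    print(bart.sample(["Hello world"], beam=5, lenpen=2.0))
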
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""
import logging
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from .hub_interface import BARTHubInterface
logger = logging.getLogger(__name__)
@register_model("bart")
class BARTModel(TransformerModel):
@classmethod
def hub_models(cls):
return {
"bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
"bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
"bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
"bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
"bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
}
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
# We follow BERT's random weight initialization
self.apply(init_bert_params)
self.classification_heads = nn.ModuleDict()
@staticmethod
def add_args(parser):
super(BARTModel, BARTModel).add_args(parser)
parser.add_argument(
"--pooler-dropout",
type=float,
metavar="D",
help="dropout probability in the masked_lm pooler layers",
)
parser.add_argument(
"--pooler-activation-fn",
choices=utils.get_available_activation_fns(),
help="activation function to use for pooler layer",
)
parser.add_argument(
"--spectral-norm-classification-head",
action="store_true",
help="Apply spectral normalization on the classification head",
)
@property
def supported_targets(self):
return {"self"}
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
features_only=False,
classification_head_name=None,
token_embeddings=None,
**kwargs,
):
if classification_head_name is not None:
features_only = True
encoder_out = self.encoder(
src_tokens,
src_lengths=src_lengths,
token_embeddings=token_embeddings,
**kwargs,
)
x, extra = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
**kwargs,
)
if classification_head_name is not None:
sentence_representation = x[
src_tokens.eq(self.encoder.dictionary.eos()), :
].view(x.size(0), -1, x.size(-1))[:, -1, :]
x = self.classification_heads[classification_head_name](
sentence_representation
)
return x, extra
@classmethod
def from_pretrained(
cls,
model_name_or_path,
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="gpt2",
**kwargs,
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
return BARTHubInterface(x["args"], x["task"], x["models"][0])
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
logger.info("Registering classification head: {0}".format(name))
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = BARTClassificationHead(
input_dim=self.args.encoder_embed_dim,
inner_dim=inner_dim or self.args.encoder_embed_dim,
num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout,
do_spectral_norm=self.args.spectral_norm_classification_head,
)
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
prefix = name + "." if name != "" else ""
current_head_names = (
[]
if not hasattr(self, "classification_heads")
else self.classification_heads.keys()
)
# Handle new classification heads present in the state dict.
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + "classification_heads."):
continue
head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
num_classes = state_dict[
prefix + "classification_heads." + head_name + ".out_proj.weight"
].size(0)
inner_dim = state_dict[
prefix + "classification_heads." + head_name + ".dense.weight"
].size(0)
if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
"deleting classification head ({}) from checkpoint "
"not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes
!= self.classification_heads[head_name].out_proj.out_features
or inner_dim
!= self.classification_heads[head_name].dense.out_features
):
logger.warning(
"deleting classification head ({}) from checkpoint "
"with different dimensions than current model: {}".format(
head_name, k
)
)
keys_to_delete.append(k)
for k in keys_to_delete:
del state_dict[k]
def truncate_emb(key):
if key in state_dict:
state_dict[key] = state_dict[key][:-1, :]
# When finetuning on translation task, remove last row of
# embedding matrix that corresponds to mask_idx token.
loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
if (
loaded_dict_size == len(self.encoder.dictionary) + 1
and "<mask>" not in self.encoder.dictionary
):
truncate_emb("encoder.embed_tokens.weight")
truncate_emb("decoder.embed_tokens.weight")
truncate_emb("encoder.output_projection.weight")
truncate_emb("decoder.output_projection.weight")
# When continued pretraining on new set of languages for mbart,
# add extra lang embeddings at the end of embed_tokens.
# Note: newly added languages are assumed to have been added at the end.
if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
self.encoder.dictionary
):
logger.info(
"Adding extra language embeddings not found in pretrained model for "
"continued pretraining of MBART on new set of languages."
)
loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][
-1, :
]
num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)
new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
new_lang_embed_to_add = new_lang_embed_to_add.to(
dtype=state_dict["encoder.embed_tokens.weight"].dtype,
)
state_dict["encoder.embed_tokens.weight"] = torch.cat(
[
state_dict["encoder.embed_tokens.weight"][
: loaded_dict_size - 1, :
],
new_lang_embed_to_add,
loaded_mask_token_embedding.unsqueeze(0),
]
)
state_dict["decoder.embed_tokens.weight"] = torch.cat(
[
state_dict["decoder.embed_tokens.weight"][
: loaded_dict_size - 1, :
],
new_lang_embed_to_add,
loaded_mask_token_embedding.unsqueeze(0),
]
)
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + "classification_heads." + k not in state_dict:
logger.info("Overwriting", prefix + "classification_heads." + k)
state_dict[prefix + "classification_heads." + k] = v
class BARTClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim,
inner_dim,
num_classes,
activation_fn,
pooler_dropout,
do_spectral_norm=False,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
if do_spectral_norm:
self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
def forward(self, features, **kwargs):
x = features
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
@register_model_architecture("bart", "bart_large")
def bart_large_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 12)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.relu_dropout = getattr(args, "relu_dropout", 0.0)
args.dropout = getattr(args, "dropout", 0.1)
args.max_target_positions = getattr(args, "max_target_positions", 1024)
args.max_source_positions = getattr(args, "max_source_positions", 1024)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", True
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
@register_model_architecture("bart", "bart_base")
def bart_base_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
bart_large_architecture(args)
@register_model_architecture("bart", "mbart_large")
def mbart_large_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
bart_large_architecture(args)
@register_model_architecture("bart", "mbart_base")
def mbart_base_architecture(args):
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
bart_base_architecture(args)
@register_model_architecture("bart", "mbart_base_wmt20")
def mbart_base_wmt20_architecture(args):
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
mbart_base_architecture(args)
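

# Usage sketch (illustrative): sentence-pair classification with the pretrained
# MNLI head, then attaching a fresh head for a downstream task. Passing the
# archive name 'bart.large.mnli' relies on the hub_models() archive map and
# downloads the checkpoint on first use; the 'my_task' head name is hypothetical.
if __name__ == "__main__":
    bart = BARTModel.from_pretrained("bart.large.mnli", checkpoint_file="model.pt")
    bart.eval()
    tokens = bart.encode(
        "BART is a seq2seq model.", "BART is not sequence to sequence."
    )
    # Expected to print 0 ("contradiction") per the fairseq BART example.
    print(bart.predict("mnli", tokens).argmax().item())
    # Attach an (untrained) 3-way classification head for a new task:
    bart.register_classification_head("my_task", num_classes=3)
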
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .fairseq_encoder import FairseqEncoder
class CompositeEncoder(FairseqEncoder):
"""
A wrapper around a dictionary of :class:`FairseqEncoder` objects.
We run forward on each encoder and return a dictionary of outputs. The first
encoder's dictionary is used for initialization.
Args:
encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
"""
def __init__(self, encoders):
super().__init__(next(iter(encoders.values())).dictionary)
self.encoders = encoders
for key in self.encoders:
self.add_module(key, self.encoders[key])
def forward(self, src_tokens, src_lengths):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): lengths of each source sentence of shape
`(batch)`
Returns:
dict:
the outputs from each Encoder
"""
encoder_out = {}
for key in self.encoders:
encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
return encoder_out
def reorder_encoder_out(self, encoder_out, new_order):
"""Reorder encoder output according to new_order."""
for key in self.encoders:
encoder_out[key] = self.encoders[key].reorder_encoder_out(
encoder_out[key], new_order
)
return encoder_out
def max_positions(self):
return min(self.encoders[key].max_positions() for key in self.encoders)
def upgrade_state_dict(self, state_dict):
for key in self.encoders:
self.encoders[key].upgrade_state_dict(state_dict)
return state_dict
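

# Minimal sketch: composing two toy encoders. ToyEncoder and ToyDictionary are
# hypothetical stand-ins (not part of fairseq); any object exposing __len__ is
# enough for this toy embedding-only encoder.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    class ToyDictionary:
        def __len__(self):
            return 100

    class ToyEncoder(FairseqEncoder):
        def __init__(self, dictionary, embed_dim=8):
            super().__init__(dictionary)
            self.embed = nn.Embedding(len(dictionary), embed_dim)

        def forward(self, src_tokens, src_lengths=None, **kwargs):
            # (batch, src_len) -> (batch, src_len, embed_dim)
            return {"encoder_out": self.embed(src_tokens)}

    d = ToyDictionary()
    composite = CompositeEncoder({"ctx": ToyEncoder(d), "doc": ToyEncoder(d)})
    out = composite(torch.randint(0, 100, (2, 5)), torch.tensor([5, 5]))
    print(out["ctx"]["encoder_out"].shape)  # torch.Size([2, 5, 8])
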
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import torch.nn as nn
from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
_GOSSIP_DISABLED = False
try:
import gossip
except ImportError:
_GOSSIP_DISABLED = True
def DistributedFairseqModel(args, model, process_group=None):
"""
Wrap a *model* to support distributed data parallel training.
This is similar to the built-in DistributedDataParallel, but allows
additional configuration of the DistributedDataParallel class to
use, and also provides easier access to the wrapped model by
forwarding requests for missing attributes to the wrapped model.
Args:
args (argparse.Namespace): fairseq args
model (BaseFairseqModel): model to wrap
"""
# determine which DDP class to extend
assert isinstance(model, nn.Module)
if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d":
ddp_class = nn.parallel.DistributedDataParallel
init_kwargs = dict(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=args.broadcast_buffers,
bucket_cap_mb=args.bucket_cap_mb,
process_group=process_group,
)
# Maintain backward compatibility
if "check_reduction" in inspect.getargspec(ddp_class)[0]:
init_kwargs["check_reduction"] = True
if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
init_kwargs["find_unused_parameters"] = args.find_unused_parameters
elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d":
ddp_class = LegacyDistributedDataParallel
init_kwargs = dict(
module=model,
world_size=args.distributed_world_size,
buffer_size=2 ** 28,
process_group=process_group,
)
elif args.distributed_wrapper == "SlowMo":
if _GOSSIP_DISABLED:
raise ImportError(
"Cannot find gossip library. Please install from: "
"github.com/facebookresearch/stochastic_gradient_push"
)
ddp_class = gossip.GossipDataParallel
# The values of slowmo_momentum below were obtained by tuning on the
# En-De 16 dataset by training the transformer_wmt_en_de_large model
if args.slowmo_momentum is None:
if args.distributed_world_size <= 16:
args.slowmo_momentum = 0.0
elif args.distributed_world_size <= 32:
args.slowmo_momentum = 0.2
elif args.distributed_world_size <= 64:
args.slowmo_momentum = 0.5
else:
args.slowmo_momentum = 0.6
init_kwargs = dict(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=args.broadcast_buffers,
nprocs_per_node=args.nprocs_per_node,
slowmo_momentum=args.slowmo_momentum,
localsgd=(args.slowmo_algorithm == "LocalSGD"),
localsgd_frequency=args.localsgd_frequency,
)
else:
raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
class _DistributedFairseqModel(ddp_class):
"""Extend DistributedDataParallel to check for missing
attributes in the wrapped module."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __getattr__(self, name):
wrapped_module = super().__getattr__("module")
if hasattr(wrapped_module, name):
return getattr(wrapped_module, name)
return super().__getattr__(name)
return _DistributedFairseqModel(**init_kwargs)
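

# Sketch of the expected call pattern for the "DDP" + "c10d" branch above.
# This assumes torch.distributed has already been initialized (e.g. via
# torch.distributed.init_process_group) and that GPU 0 is available; only the
# args fields read by that branch are populated, and the wrapped nn.Linear is
# a hypothetical stand-in for a fairseq model.
if __name__ == "__main__":
    from argparse import Namespace

    toy_model = nn.Linear(16, 16).cuda()
    ddp_args = Namespace(
        distributed_wrapper="DDP",
        ddp_backend="c10d",
        device_id=0,
        broadcast_buffers=False,
        bucket_cap_mb=25,
        find_unused_parameters=False,
    )
    wrapped = DistributedFairseqModel(ddp_args, toy_model)
    # Missing attributes are forwarded to the wrapped module:
    print(wrapped.weight.shape)  # torch.Size([16, 16])
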
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
import torch.nn as nn
from fairseq import utils
from torch import Tensor
class FairseqDecoder(nn.Module):
"""Base class for decoders."""
def __init__(self, dictionary):
super().__init__()
self.dictionary = dictionary
self.onnx_trace = False
def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
"""
Args:
prev_output_tokens (LongTensor): shifted output tokens of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (dict, optional): output from the encoder, used for
encoder-side attention
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, **kwargs
)
x = self.output_layer(x)
return x, extra
def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
"""
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
raise NotImplementedError
def output_layer(self, features, **kwargs):
"""
Project features to the default output size, e.g., vocabulary size.
Args:
features (Tensor): features returned by *extract_features*.
"""
raise NotImplementedError
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
assert "target" in sample
target = sample["target"]
else:
target = None
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
return out.exp_() if not log_probs else out
logits = net_output[0]
if log_probs:
return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
else:
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def max_positions(self):
"""Maximum input length supported by the decoder."""
return 1e6 # an arbitrary large number
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
def prepare_for_onnx_export_(self):
self.onnx_trace = True
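

# Minimal sketch of a concrete decoder: subclasses only need to provide
# `extract_features` and `output_layer`. ToyDictionary, the vocabulary size,
# and the embedding dimension are hypothetical.
if __name__ == "__main__":
    import torch

    class ToyDictionary:
        def __len__(self):
            return 100

    class ToyDecoder(FairseqDecoder):
        def __init__(self, dictionary, embed_dim=16):
            super().__init__(dictionary)
            self.embed = nn.Embedding(len(dictionary), embed_dim)
            self.proj = nn.Linear(embed_dim, len(dictionary))

        def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
            # (batch, tgt_len) -> (batch, tgt_len, embed_dim); a real decoder
            # would attend over `encoder_out` here.
            return self.embed(prev_output_tokens), {}

        def output_layer(self, features, **kwargs):
            # project features to vocabulary-sized logits
            return self.proj(features)

    dec = ToyDecoder(ToyDictionary())
    logits, _ = dec(torch.randint(0, 100, (2, 7)))
    probs = dec.get_normalized_probs((logits, None), log_probs=False)
    print(logits.shape, probs.shape)  # both torch.Size([2, 7, 100])
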
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, NamedTuple, Optional
import torch
import torch.nn as nn
from torch import Tensor
EncoderOut = NamedTuple(
"EncoderOut",
[
("encoder_out", Tensor), # T x B x C
("encoder_padding_mask", Optional[Tensor]), # B x T
("encoder_embedding", Optional[Tensor]), # B x T x C
("encoder_states", Optional[List[Tensor]]), # List[T x B x C]
("src_tokens", Optional[Tensor]), # B x T
("src_lengths", Optional[Tensor]), # B x 1
],
)
class FairseqEncoder(nn.Module):
"""Base class for encoders."""
def __init__(self, dictionary):
super().__init__()
self.dictionary = dictionary
def forward(self, src_tokens, src_lengths=None, **kwargs):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): lengths of each source sentence of shape
`(batch)`
"""
raise NotImplementedError
def forward_torchscript(self, net_input: Dict[str, Tensor]):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
if torch.jit.is_scripting():
return self.forward(
src_tokens=net_input["src_tokens"],
src_lengths=net_input["src_lengths"],
)
else:
return self.forward_non_torchscript(net_input)
@torch.jit.unused
def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
encoder_input = {
k: v for k, v in net_input.items() if k != "prev_output_tokens"
}
return self.forward(**encoder_input)
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to `new_order`.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
`encoder_out` rearranged according to `new_order`
"""
raise NotImplementedError
def max_positions(self):
"""Maximum input length supported by the encoder."""
return 1e6 # an arbitrary large number
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
def set_num_updates(self, num_updates):
"""State from trainer to pass along to model at every update."""
def _apply(m):
if hasattr(m, "set_num_updates") and m != self:
m.set_num_updates(num_updates)
self.apply(_apply)
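

# Minimal sketch of a concrete encoder that fills in the EncoderOut fields
# defined above. ToyDictionary, the vocabulary size, and the embedding
# dimension are hypothetical.
if __name__ == "__main__":
    class ToyDictionary:
        def __len__(self):
            return 100

        def pad(self):
            return 1

    class ToyEncoder(FairseqEncoder):
        def __init__(self, dictionary, embed_dim=16):
            super().__init__(dictionary)
            self.embed = nn.Embedding(
                len(dictionary), embed_dim, padding_idx=dictionary.pad()
            )

        def forward(self, src_tokens, src_lengths=None, **kwargs):
            x = self.embed(src_tokens).transpose(0, 1)  # B x T x C -> T x B x C
            return EncoderOut(
                encoder_out=x,
                encoder_padding_mask=src_tokens.eq(self.dictionary.pad()),
                encoder_embedding=None,
                encoder_states=None,
                src_tokens=None,
                src_lengths=None,
            )

    enc = ToyEncoder(ToyDictionary())
    out = enc(torch.randint(2, 100, (2, 5)))
    print(out.encoder_out.shape)  # torch.Size([5, 2, 16])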