# Copyright (c) 2017 Elad Hoffer
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch.nn as nn
from torch.nn.functional import log_softmax


class Seq2Seq(nn.Module):
    """
    Generic Seq2Seq module, with an encoder and a decoder.
    """
    def __init__(self, encoder=None, decoder=None, batch_first=False):
        """
        Constructor for the Seq2Seq module.

        :param encoder: encoder module
        :param decoder: decoder module
        :param batch_first: if True the model uses (batch, seq, feature)
            tensors, if False the model uses (seq, batch, feature) tensors
        """
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.batch_first = batch_first

    def encode(self, inputs, lengths):
        """
        Applies the encoder to inputs with the given input sequence lengths.

        :param inputs: tensor with inputs (batch, seq_len) if 'batch_first'
            else (seq_len, batch)
        :param lengths: vector with sequence lengths (excluding padding)
        """
        return self.encoder(inputs, lengths)

    def decode(self, inputs, context, inference=False):
        """
        Applies the decoder to inputs, given the context from the encoder.

        :param inputs: tensor with inputs (batch, seq_len) if 'batch_first'
            else (seq_len, batch)
        :param context: context from the encoder
        :param inference: if True inference mode, if False training mode
        """
        return self.decoder(inputs, context, inference)

    def generate(self, inputs, context, beam_size):
        """
        Autoregressive generator, works with the SequenceGenerator class.
        Executes the decoder (in inference mode), then applies log_softmax
        and topK for inference with beam search decoding.

        :param inputs: tensor with inputs to the decoder
        :param context: context from the encoder
        :param beam_size: beam size for the generator

        returns: (words, logprobs, scores, new_context)
            words: indices of topK tokens
            logprobs: log probabilities of topK tokens
            scores: scores from the attention module (for coverage penalty)
            new_context: new decoder context, includes new hidden states for
                decoder RNN cells
        """
        logits, scores, new_context = self.decode(inputs, context, True)
        logprobs = log_softmax(logits, dim=-1)
        logprobs, words = logprobs.topk(beam_size, dim=-1)
        return words, logprobs, scores, new_context
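

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). ToyEncoder and
# ToyDecoder below are hypothetical stand-ins that only mimic the interfaces
# Seq2Seq expects: encoder(inputs, lengths) -> context and
# decoder(inputs, context, inference) -> (logits, scores, new_context).
# The real GNMT encoder/decoder modules are considerably more elaborate;
# this is just to show how encode() and generate() fit together.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    class ToyEncoder(nn.Module):
        def __init__(self, vocab_size=32, hidden=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)

        def forward(self, inputs, lengths):
            # the "context" here is simply the embedded source sequence
            return self.embed(inputs)

    class ToyDecoder(nn.Module):
        def __init__(self, vocab_size=32, hidden=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.proj = nn.Linear(hidden, vocab_size)

        def forward(self, inputs, context, inference=False):
            logits = self.proj(self.embed(inputs))
            scores = None          # attention scores (unused in this toy)
            new_context = context  # no recurrent state to update here
            return logits, scores, new_context

    model = Seq2Seq(ToyEncoder(), ToyDecoder(), batch_first=True)
    src = torch.randint(0, 32, (2, 5))          # (batch, seq_len)
    lengths = torch.tensor([5, 3])
    context = model.encode(src, lengths)

    bos = torch.zeros(2, 1, dtype=torch.long)   # hypothetical BOS token id 0
    words, logprobs, scores, _ = model.generate(bos, context, beam_size=3)
    print(words.shape, logprobs.shape)          # both (2, 1, 3)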