# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.modules import (
    LayerNorm,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
)

from . import build_monotonic_attention


class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
    def forward(self, x, encoder_padding_mask):
        # Build a causal (strictly upper-triangular) mask so that each source
        # position attends only to itself and earlier positions, as required
        # for a unidirectional encoder in simultaneous translation.
        seq_len, _, _ = x.size()
        attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
        attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
        return super().forward(x, encoder_padding_mask, attn_mask)


class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        # Build the base decoder layer without standard encoder attention,
        # then replace it with a monotonic attention module and a fresh
        # layer norm for the encoder-attention sublayer.
        super().__init__(
            args,
            no_encoder_attn=True,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.encoder_attn = build_monotonic_attention(args)
        self.encoder_attn_layer_norm = LayerNorm(
            self.embed_dim, export=getattr(args, "char_inputs", False)
        )

    def prune_incremental_state(self, incremental_state):
        """Drop the most recent step from the cached self-attention keys and
        values, e.g. when the policy decides to read another source token
        instead of keeping the last emitted target token."""

        def prune(module):
            input_buffer = module._get_input_buffer(incremental_state)
            if not input_buffer:
                # Nothing cached yet (e.g. before the first decoding step).
                return
            for key in ["prev_key", "prev_value"]:
                if input_buffer[key].size(2) > 1:
                    # Remove the last time step along the sequence dimension.
                    input_buffer[key] = input_buffer[key][:, :, :-1, :]
                else:
                    # Only one step is cached; clear the buffer entirely.
                    input_buffer = {}
                    break
            module._set_input_buffer(incremental_state, input_buffer)

        prune(self.self_attn)

    def get_steps(self, incremental_state):
        # Current source step of the monotonic attention head
        # (0 before any reading has happened).
        return self.encoder_attn._get_monotonic_buffer(incremental_state).get(
            "step", 0
        )
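

# --- Illustrative sketch (not part of the original module) ---
# A minimal, self-contained demonstration of the two tensor manipulations
# above, assuming only `torch`. The shapes here are hypothetical examples:
# the mask matches the one built in TransformerMonotonicEncoderLayer.forward,
# and the slicing mirrors what prune() applies to a cached "prev_key" tensor
# of shape (batch, num_heads, steps, head_dim).
if __name__ == "__main__":
    import torch

    # Causal mask: positions strictly above the diagonal become -inf, so
    # source position i cannot attend to any j > i.
    seq_len = 4
    x = torch.zeros(seq_len, 1, 8)  # (time, batch, embed_dim), dummy input
    attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
    attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
    print(attn_mask)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])

    # Pruning: drop the most recent cached step along the steps dimension.
    prev_key = torch.randn(1, 4, 5, 16)  # (batch, heads, steps, head_dim)
    pruned = prev_key[:, :, :-1, :]
    print(prev_key.shape, "->", pruned.shape)  # steps: 5 -> 4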