# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
    TransformerMonotonicDecoderLayer,
    TransformerMonotonicEncoderLayer,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
    transformer_iwslt_de_en,
    transformer_vaswani_wmt_en_de_big,
)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)


@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

    def _indices_from_states(self, states):
        if isinstance(states["indices"]["src"], list):
            if next(self.parameters()).is_cuda:
                tensor = torch.cuda.LongTensor
            else:
                tensor = torch.LongTensor

            src_indices = tensor(
                [states["indices"]["src"][: 1 + states["steps"]["src"]]]
            )
            tgt_indices = tensor(
                [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
            )
        else:
            src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
            tgt_indices = states["indices"]["tgt"]

        return src_indices, None, tgt_indices

    def predict_from_states(self, states):
        decoder_states = self.decoder.output_layer(states["decoder_features"])
        lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)

        index = lprobs.argmax(dim=-1)
        token = self.decoder.dictionary.string(index)

        return token, index[0, 0].item()

    def decision_from_states(self, states):
        """
        Take a states dictionary as input and give the agent a decision on
        whether to read another token from the server. The decoder states are
        also computed here, so a target token can be generated directly
        without recomputing everything.
        """
        self.eval()

        if len(states["tokens"]["src"]) == 0:
            return 0

        src_indices, src_lengths, tgt_indices = self._indices_from_states(states)

        # Update encoder states if needed
        if (
            "encoder_states" not in states
            or states["encoder_states"][0].size(1) <= states["steps"]["src"]
        ):
            encoder_out_dict = self.encoder(src_indices, src_lengths)
            states["encoder_states"] = encoder_out_dict
        else:
            encoder_out_dict = states["encoder_states"]

        # "online" means the model may still need more source tokens
        states["model_states"]["online"] = not (
            states["finish_read"]
            and len(states["tokens"]["src"]) == states["steps"]["src"]
        )
        states["model_states"]["steps"] = states["steps"]

        x, outputs = self.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=encoder_out_dict,
            incremental_state=states["model_states"],
            features_only=True,
        )

        # Cache the decoder features so predict_from_states can reuse them.
        states["decoder_features"] = x

        return outputs["action"]
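

# A minimal usage sketch of how an external simultaneous-translation agent
# might drive the two entry points above. Illustrative only: the ``states``
# layout mirrors the keys accessed in ``_indices_from_states`` and
# ``decision_from_states``, and ``read_fn`` / ``write_fn`` are hypothetical
# agent callbacks, not fairseq APIs.
def _example_read_write_loop(model, states, read_fn, write_fn, max_len=200):
    """Alternate READ (action 0) and WRITE (action 1) steps.

    ``read_fn(states)`` is assumed to append the next source token to
    ``states["tokens"]["src"]`` / ``states["indices"]["src"]``, advance
    ``states["steps"]["src"]``, and set ``states["finish_read"]`` once the
    source is exhausted; ``write_fn(token)`` emits a predicted target token.
    """
    for _ in range(max_len):
        action = model.decision_from_states(states)
        if action == 0 and not states["finish_read"]:
            # READ: request one more source token before predicting.
            read_fn(states)
        else:
            # WRITE: reuse the decoder features cached by decision_from_states
            # in states["decoder_features"] to predict the next target token.
            token, index = model.predict_from_states(states)
            states["indices"]["tgt"].append(index)
            write_fn(token)
            if index == model.decoder.dictionary.eos():
                break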


class TransformerMonotonicEncoder(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        self.dictionary = dictionary
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [TransformerMonotonicEncoderLayer(args) for _ in range(args.encoder_layers)]
        )


class TransformerMonotonicDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerMonotonicDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)

        self.dictionary = dictionary
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerMonotonicDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )

    def pre_attention(
        self, prev_output_tokens, encoder_out_dict, incremental_state=None
    ):
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = encoder_out_dict.encoder_out
        encoder_padding_mask = encoder_out_dict.encoder_padding_mask

        return x, encoder_out, encoder_padding_mask

    def post_attention(self, x):
        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x
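
    # A toy illustration (schematic shapes) of the per-layer step update done
    # in ``extract_features`` below when decoding online: with two monotonic
    # heads at curr_steps = [[2, 3]] and gathered p_choose = [[0.8, 0.3]],
    # only the head with p_choose < 0.5 advances one source step, giving
    # new_steps = [[2, 4]]. If any entry of new_steps reaches the number of
    # source steps read so far (incremental_state["steps"]["src"]), the
    # already-processed layers are pruned and the decoder returns
    # {"action": 0}, i.e. it asks the agent to read another source token.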

    def extract_features(
        self, prev_output_tokens, encoder_out, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only returns features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
            prev_output_tokens, encoder_out, incremental_state
        )
        attn = None
        inner_states = [x]
        attn_list = []
        step_list = []

        for i, layer in enumerate(self.layers):
            x, attn, _ = layer(
                x=x,
                encoder_out=encoder_outs,
                encoder_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None
                else None,
            )

            inner_states.append(x)
            attn_list.append(attn)

            if incremental_state is not None:
                curr_steps = layer.get_steps(incremental_state)
                step_list.append(curr_steps)

                if incremental_state.get("online", False):
                    p_choose = (
                        attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
                    )
                    new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)

                    if (new_steps >= incremental_state["steps"]["src"]).any():
                        # The monotonic heads need a source position that has not
                        # been read yet, so ask the agent to read more (action 0).
                        # This target step will be re-run later; prune the
                        # self-attention saved_state stored by the layers
                        # processed so far, otherwise it would be duplicated
                        # on the re-run.
                        for j in range(i + 1):
                            self.layers[j].prune_incremental_state(incremental_state)

                        return x, {"action": 0}

        if incremental_state is not None and not incremental_state.get("online", False):
            # Offline (fast) evaluation: record the furthest source step
            # reached across all layers at this target position.
            fastest_step = (
                torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
            )

            if "fastest_step" in incremental_state:
                incremental_state["fastest_step"] = torch.cat(
                    [incremental_state["fastest_step"], fastest_step], dim=1
                )
            else:
                incremental_state["fastest_step"] = fastest_step

        x = self.post_attention(x)

        return x, {
            "action": 1,
            "attn_list": attn_list,
            "step_list": step_list,
            "encoder_out": encoder_out,
            "encoder_padding_mask": encoder_padding_mask,
        }

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        if "fastest_step" in incremental_state:
            incremental_state["fastest_step"] = incremental_state[
                "fastest_step"
            ].index_select(0, new_order)


@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_architecture(args):
    base_architecture(args)
    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
    transformer_iwslt_de_en(args)
    base_monotonic_architecture(args)


# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
    transformer_vaswani_wmt_en_de_big(args)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    # Reuse the en-de big architecture settings.
    transformer_monotonic_vaswani_wmt_en_de_big(args)


@register_model_architecture(
    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
    transformer_iwslt_de_en(args)
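

# Usage note (a sketch, not a full command line): the architectures registered
# above are selected with fairseq-train's standard --arch flag, e.g.
# ``--arch transformer_monotonic_iwslt_de_en``, together with the
# simultaneous-translation task and criterion provided elsewhere in this
# example directory (their exact names are not defined in this file).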