Commit 9398a282 authored by Xilun Chen, committed by Facebook Github Bot

Add an extract_features option for LSTMDecoder

Summary:
This diff adds a new option to LSTMDecoder for obtaining the unprojected decoder outputs, i.e., the hidden states before the final output projection layer.

The behavior of the original forward() method is unchanged, but its body is now split into two parts: extract_features() and output_layer().
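A minimal sanity check of that equivalence might look like the following sketch, where `decoder`, `tokens`, and `enc_out` are hypothetical stand-ins for a constructed LSTMDecoder, the previous output tokens, and the encoder output dict:

import torch

decoder.eval()  # disable dropout so both code paths are deterministic
logits, attn = decoder(tokens, enc_out)  # original entry point
features, _ = decoder.extract_features(tokens, enc_out)
assert torch.allclose(logits, decoder.output_layer(features))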

extract_features() returns the decoder's hidden states, which gives models built on top of the decoder more flexibility.

For instance, the unprojected decoder outputs are needed to implement a copy/pointer attention mechanism, which uses the decoder output to decide whether to copy certain tokens from the source sequence.
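
As an illustration only (this module is not part of the diff; all names are hypothetical), a copy/pointer gate could consume the unprojected features like so:

import torch
import torch.nn as nn

class CopyGate(nn.Module):
    # Hypothetical helper, not part of this diff: mixes a generation
    # distribution with a copy distribution, using the unprojected decoder
    # features from extract_features() to compute the mixing weight.
    def __init__(self, hidden_dim):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, 1)

    def forward(self, features, gen_probs, copy_probs):
        # features:   B x T x H decoder hidden states (pre-projection)
        # gen_probs:  B x T x V softmax over the output vocabulary
        # copy_probs: B x T x V attention mass scattered onto source token ids
        p_copy = torch.sigmoid(self.gate(features))  # B x T x 1, broadcasts over V
        return p_copy * copy_probs + (1 - p_copy) * gen_probs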

Reviewed By: myleott

Differential Revision: D18650255

fbshipit-source-id: 321c3085676d98b8b4f4ad6102917c94800643a5
parent 9f6b552d
@@ -356,6 +356,17 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
     def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        x, attn_scores = self.extract_features(
+            prev_output_tokens, encoder_out, incremental_state
+        )
+        return self.output_layer(x), attn_scores
+
+    def extract_features(
+        self, prev_output_tokens, encoder_out, incremental_state=None
+    ):
+        """
+        Similar to *forward* but only return features.
+        """
         encoder_padding_mask = encoder_out['encoder_padding_mask']
         encoder_out = encoder_out['encoder_out']
 
@@ -429,22 +440,25 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         # T x B x C -> B x T x C
         x = x.transpose(1, 0)
 
+        if hasattr(self, 'additional_fc') and self.adaptive_softmax is None:
+            x = self.additional_fc(x)
+            x = F.dropout(x, p=self.dropout_out, training=self.training)
+
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
         if not self.training and self.need_attn:
             attn_scores = attn_scores.transpose(0, 2)
         else:
             attn_scores = None
+        return x, attn_scores
 
-        # project back to size of vocabulary
+    def output_layer(self, x):
+        """Project features to the vocabulary size."""
         if self.adaptive_softmax is None:
-            if hasattr(self, 'additional_fc'):
-                x = self.additional_fc(x)
-                x = F.dropout(x, p=self.dropout_out, training=self.training)
             if self.share_input_output_embed:
                 x = F.linear(x, self.embed_tokens.weight)
             else:
                 x = self.fc_out(x)
-        return x, attn_scores
+        return x
 
     def reorder_incremental_state(self, incremental_state, new_order):
         super().reorder_incremental_state(incremental_state, new_order)
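
Putting the pieces together, downstream code could use the new API roughly as follows (a sketch; `decoder`, `copy_gate`, `copy_probs`, and the input tensors are all assumed to exist):

import torch.nn.functional as F

# Get unprojected features once, reuse them for both the gate and the logits.
features, attn_scores = decoder.extract_features(prev_output_tokens, encoder_out)
gen_probs = F.softmax(decoder.output_layer(features), dim=-1)
# copy_probs would come from scattering attn_scores onto source token ids
final_probs = copy_gate(features, gen_probs, copy_probs)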