Commit a0c5f9b8 authored by Myle Ott, committed by Facebook Github Bot

Allow TransformerSentenceEncoder to return only last state

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/578

Differential Revision: D15352060

Pulled By: myleott

fbshipit-source-id: 7dc2fceca37ec96c89356662831b0d82f28bef6f
parent 52778827
@@ -152,7 +152,8 @@ class TransformerSentenceEncoder(nn.Module):
     def forward(
         self,
         tokens: torch.Tensor,
-        segment_labels: torch.Tensor
+        segment_labels: torch.Tensor,
+        last_state_only: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # compute padding mask. This is needed for multi-head attention
@@ -181,13 +182,17 @@ class TransformerSentenceEncoder(nn.Module):
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)

-        inner_states = [x]
+        inner_states = []
+        if not last_state_only:
+            inner_states.append(x)

         for layer in self.layers:
             x, _ = layer(
                 x,
                 self_attn_padding_mask=padding_mask,
             )
-            inner_states.append(x)
+            if not last_state_only:
+                inner_states.append(x)

         # T x B x C -> B x T x C
@@ -195,4 +200,7 @@ class TransformerSentenceEncoder(nn.Module):
         sentence_rep = x[:, 0, :]

+        if last_state_only:
+            inner_states = [x]
+
         return inner_states, sentence_rep
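For reference, a minimal usage sketch of the new flag. The constructor arguments below are illustrative assumptions, not the full TransformerSentenceEncoder signature; check the class definition in your fairseq checkout for the actual options.

import torch
from fairseq.modules import TransformerSentenceEncoder

# Illustrative hyperparameters; the real constructor takes many more options.
encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=1000,
    num_encoder_layers=2,
    embedding_dim=64,
    num_attention_heads=4,
)

tokens = torch.randint(2, 1000, (8, 16))   # B x T batch of token ids
segment_labels = torch.zeros_like(tokens)  # single-segment input

# Default behavior: inner_states holds the embedding output
# plus the output of every encoder layer.
inner_states, sentence_rep = encoder(tokens, segment_labels)

# With last_state_only=True, inner_states contains only the final
# layer's output, avoiding the memory cost of retaining all
# intermediate states.
(last_state,), sentence_rep = encoder(
    tokens, segment_labels, last_state_only=True
)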