"cacheflow/master/server.py" did not exist on "e9d3f2ff7772c8efe41dc805cec71c223ec18ec8"
Commit a0c5f9b8 authored by Myle Ott, committed by Facebook GitHub Bot

Allow TransformerSentenceEncoder to return only last state

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/578

Differential Revision: D15352060

Pulled By: myleott

fbshipit-source-id: 7dc2fceca37ec96c89356662831b0d82f28bef6f
parent 52778827
@@ -152,7 +152,8 @@ class TransformerSentenceEncoder(nn.Module):
     def forward(
         self,
         tokens: torch.Tensor,
-        segment_labels: torch.Tensor
+        segment_labels: torch.Tensor,
+        last_state_only: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

         # compute padding mask. This is needed for multi-head attention
@@ -181,18 +182,25 @@ class TransformerSentenceEncoder(nn.Module):
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)

-        inner_states = [x]
+        inner_states = []
+        if not last_state_only:
+            inner_states.append(x)

         for layer in self.layers:
             x, _ = layer(
                 x,
                 self_attn_padding_mask=padding_mask,
             )
-            inner_states.append(x)
+            if not last_state_only:
+                inner_states.append(x)

         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

         sentence_rep = x[:, 0, :]

+        if last_state_only:
+            inner_states = [x]
+
         return inner_states, sentence_rep
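
For illustration, here is a minimal self-contained sketch of the pattern this diff introduces: a stack of layers whose forward pass either collects every intermediate state or, when last_state_only=True, skips that bookkeeping and returns only the final state. The ToyEncoder class below, its dimensions, and the call sites are hypothetical stand-ins, not fairseq's actual TransformerSentenceEncoder API.

    # Toy illustration of the last_state_only pattern (not fairseq code).
    from typing import List, Tuple

    import torch
    import torch.nn as nn


    class ToyEncoder(nn.Module):
        def __init__(self, dim: int = 16, num_layers: int = 3):
            super().__init__()
            # Stand-in "layers"; fairseq uses TransformerSentenceEncoderLayer here.
            self.layers = nn.ModuleList(
                nn.Linear(dim, dim) for _ in range(num_layers)
            )

        def forward(
            self,
            x: torch.Tensor,
            last_state_only: bool = False,
        ) -> Tuple[List[torch.Tensor], torch.Tensor]:
            inner_states = []
            if not last_state_only:
                inner_states.append(x)
            for layer in self.layers:
                x = layer(x)
                if not last_state_only:
                    inner_states.append(x)
            if last_state_only:
                # Only the final state is kept, so callers that need just the
                # top-layer output avoid holding every intermediate tensor.
                inner_states = [x]
            return inner_states, x


    encoder = ToyEncoder()
    tokens = torch.randn(2, 5, 16)  # B x T x C
    all_states, last = encoder(tokens)                    # len(all_states) == 4
    only_last, _ = encoder(tokens, last_state_only=True)  # len(only_last) == 1

Since the default is last_state_only=False, existing callers keep their previous behavior; only callers that opt in drop the per-layer state collection.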