Commit 8d9063fe authored by Kartikay Khandelwal, committed by Facebook Github Bot

Mask out embeddings associated with padding (#710)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/710

Previously there was a bug in how we dealt with padding when computing the input representation from the segment and position embeddings. D15144912 fixed this by adding an offset based on the padding id. However, this makes assumptions about the padding id that may not hold true for vocabularies built outside of PyText and fairseq. Based on a discussion with barlaso, this diff zeroes out all the embeddings associated with the padding.
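For illustration, here is a minimal, self-contained sketch of the idea (not the fairseq module itself): instead of relying on a particular padding id, multiply the embedded input by (1 - padding_mask) so every padded position contributes a zero vector. The vocabulary size, padding id, and example tokens below are made up for the demo.

import torch

embedding_dim = 8
padding_idx = 1  # hypothetical pad id; the real value depends on the vocabulary

tokens = torch.tensor([
    [4, 5, 6, 1, 1],    # last two positions are padding
    [7, 8, 9, 10, 1],   # last position is padding
])
embed = torch.nn.Embedding(20, embedding_dim)

x = embed(tokens)                      # B x T x C
padding_mask = tokens.eq(padding_idx)  # B x T, True at padded positions

# Zero out every padded position, regardless of which id is used for padding.
x = x * (1 - padding_mask.unsqueeze(-1).float())

assert x[0, 3:].abs().sum() == 0 and x[1, 4:].abs().sum() == 0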

Reviewed By: borguz

Differential Revision: D15209395

fbshipit-source-id: 5573020e610f5466e673fe3845c3ed34ebb5c44d
parent 0add50c2
@@ -100,7 +100,7 @@ class TransformerSentenceEncoder(nn.Module):
         )
         self.segment_embeddings = (
-            nn.Embedding(self.num_segments, self.embedding_dim, self.padding_idx)
+            nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
             if self.num_segments > 0
             else None
         )
@@ -110,7 +110,7 @@ class TransformerSentenceEncoder(nn.Module):
                 self.max_seq_len,
                 self.embedding_dim,
                 self.padding_idx,
-                self.learned_pos_embedding,
+                learned=self.learned_pos_embedding,
             )
             if self.use_position_embeddings
             else None
@@ -162,12 +162,17 @@ class TransformerSentenceEncoder(nn.Module):
         )
         x = self.embed_tokens(tokens)
         if positions is not None:
             x += positions
         if segments is not None:
             x += segments
         x = F.dropout(x, p=self.dropout, training=self.training)
+        # account for padding while computing the representation
+        if padding_mask is not None:
+            x *= (1 - padding_mask.unsqueeze(-1).float())
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
         inner_states = [x]
...
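As a usage sketch of the forward-pass logic touched by this diff: the combined token/position/segment embeddings go through dropout, the padded positions are zeroed, and the result is transposed from B x T x C to T x B x C. The helper name and default dropout below are hypothetical; this only mirrors the few changed lines, not the full TransformerSentenceEncoder.forward.

import torch
import torch.nn.functional as F

def embed_with_padding_mask(x, padding_mask, dropout_p=0.1, training=False):
    # x: B x T x C summed token/position/segment embeddings
    # padding_mask: B x T, True at padded positions, or None if nothing is padded
    x = F.dropout(x, p=dropout_p, training=training)
    # account for padding while computing the representation
    if padding_mask is not None:
        x = x * (1 - padding_mask.unsqueeze(-1).float())
    # B x T x C -> T x B x C
    return x.transpose(0, 1)

x = torch.randn(2, 4, 8)
padding_mask = torch.tensor([[False, False, True, True],
                             [False, False, False, True]])
print(embed_with_padding_mask(x, padding_mask).shape)  # torch.Size([4, 2, 8])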