Commit c2820af0 authored by Kartikay Khandelwal, committed by Facebook Github Bot

Rename embedding layers to be the same as NMT (#628)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/628

Rename the embedding and self-attention submodules in TransformerSentenceEncoder (embed_tokens, embed_positions, self_attn) so the attribute names are compatible with the transformer NMT model.
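Since these are nn.Module attribute renames, parameter keys in saved checkpoints change as well (e.g. token_embeddings.weight becomes embed_tokens.weight). A minimal sketch of remapping an old state dict to the new names; this helper is hypothetical and not part of this PR:

# Old attribute name -> new NMT-compatible name (the renames in this diff).
RENAMES = {
    "token_embeddings": "embed_tokens",
    "position_embeddings": "embed_positions",
    "self_attention": "self_attn",
}

def migrate_state_dict(old_state):
    """Rewrite pre-rename parameter keys so older checkpoints still load."""
    new_state = {}
    for key, tensor in old_state.items():
        for old_name, new_name in RENAMES.items():
            key = key.replace(old_name, new_name)
        new_state[key] = tensor
    return new_state

# Usage (path is illustrative):
# model.load_state_dict(migrate_state_dict(torch.load("old.pt", map_location="cpu")))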

Reviewed By: liezl200

Differential Revision: D14836883

fbshipit-source-id: 2240f61bf40b191d01b4efdaac4dd7562b4166c6
parent 94e9d77c
@@ -106,7 +106,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.use_position_embeddings = use_position_embeddings
         self.apply_bert_init = apply_bert_init
-        self.token_embeddings = nn.Embedding(
+        self.embed_tokens = nn.Embedding(
             self.vocab_size, self.embedding_dim, self.padding_idx
         )
@@ -116,7 +116,7 @@ class TransformerSentenceEncoder(nn.Module):
             else None
         )
-        self.position_embeddings = (
+        self.embed_positions = (
             PositionalEmbedding(
                 self.max_seq_len,
                 self.embedding_dim,
@@ -161,8 +161,8 @@ class TransformerSentenceEncoder(nn.Module):
         # embed positions
         positions = (
-            self.position_embeddings(tokens)
-            if self.position_embeddings is not None else None
+            self.embed_positions(tokens)
+            if self.embed_positions is not None else None
         )
         # embed segments
@@ -172,7 +172,7 @@ class TransformerSentenceEncoder(nn.Module):
             else None
         )
-        x = self.token_embeddings(tokens)
+        x = self.embed_tokens(tokens)
         if positions is not None:
             x += positions
         if segments is not None:
...
@@ -51,7 +51,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         # Initialize blocks
         self.activation_fn = gelu if use_gelu else F.relu
-        self.self_attention = MultiheadAttention(
+        self.self_attn = MultiheadAttention(
             self.embedding_dim, num_attention_heads, dropout=attention_dropout
         )
@@ -97,7 +97,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         residual = x
         x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
-        x, attn = self.self_attention(
+        x, attn = self.self_attn(
             query=x,
             key=x,
             value=x,
...
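For a sense of how the renamed attributes read in use, here is a self-contained sketch with simplified stand-in modules (nn.Embedding for the positional embedding, nn.MultiheadAttention for the attention block); these are not the real fairseq classes, only the attribute names are taken from this diff:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Simplified stand-in for TransformerSentenceEncoder using the new names."""
    def __init__(self, vocab_size=100, dim=16, padding_idx=0, max_len=32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim, padding_idx)  # was token_embeddings
        self.embed_positions = nn.Embedding(max_len, dim)               # was position_embeddings
        self.self_attn = nn.MultiheadAttention(dim, num_heads=4)        # was self_attention

    def forward(self, tokens):
        positions = torch.arange(tokens.size(1), device=tokens.device)
        x = self.embed_tokens(tokens) + self.embed_positions(positions)
        x = x.transpose(0, 1)  # nn.MultiheadAttention expects (T, B, C)
        x, _ = self.self_attn(query=x, key=x, value=x)
        return x.transpose(0, 1)

tokens = torch.randint(1, 100, (2, 8))  # batch of 2, length 8
print(TinyEncoder()(tokens).shape)      # torch.Size([2, 8, 16])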