"vscode:/vscode.git/clone" did not exist on "d87ce2cefc6612fa95cb6d58fa3d74080d18b312"
Commit c2820af0 authored by Kartikay Khandelwal, committed by Facebook GitHub Bot

Rename embedding layers to be the same as NMT (#628)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/628

Rename the embedding layers in TransformerSentenceEncoder to match the naming used by the NMT transformer model, so the two implementations stay compatible (`token_embeddings` becomes `embed_tokens`, `position_embeddings` becomes `embed_positions`).
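
As context, a minimal sketch of why the attribute names matter: module attribute names determine state-dict keys, so an encoder whose embeddings are registered as `embed_tokens` and `embed_positions` produces checkpoints whose keys line up with the NMT transformer's. The `TinySentenceEncoder` below and its dimensions are hypothetical illustration, not the fairseq API:

```python
import torch
import torch.nn as nn

# Minimal sketch: `TinySentenceEncoder` is a hypothetical stand-in for
# TransformerSentenceEncoder, not the fairseq class itself.
class TinySentenceEncoder(nn.Module):
    def __init__(self, vocab_size=100, embedding_dim=16, padding_idx=0, max_seq_len=32):
        super().__init__()
        # Renamed from `token_embeddings`: same module, NMT-style attribute name.
        self.embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx)
        # Renamed from `position_embeddings`; a plain learned table stands in
        # for fairseq's PositionalEmbedding here.
        self.embed_positions = nn.Embedding(max_seq_len, embedding_dim)

    def forward(self, tokens):
        # tokens: (batch, seq_len); positions broadcast over the batch dim.
        positions = torch.arange(tokens.size(1), device=tokens.device)
        return self.embed_tokens(tokens) + self.embed_positions(positions)

# The rename changes the state-dict keys, which is what compatibility hinges on:
enc = TinySentenceEncoder()
print(sorted(enc.state_dict().keys()))
# ['embed_positions.weight', 'embed_tokens.weight'] -- matches the NMT transformer.
```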

Reviewed By: liezl200

Differential Revision: D14836883

fbshipit-source-id: 2240f61bf40b191d01b4efdaac4dd7562b4166c6
parent 94e9d77c
@@ -106,7 +106,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.use_position_embeddings = use_position_embeddings
         self.apply_bert_init = apply_bert_init
-        self.token_embeddings = nn.Embedding(
+        self.embed_tokens = nn.Embedding(
             self.vocab_size, self.embedding_dim, self.padding_idx
         )
@@ -116,7 +116,7 @@ class TransformerSentenceEncoder(nn.Module):
             else None
         )
-        self.position_embeddings = (
+        self.embed_positions = (
             PositionalEmbedding(
                 self.max_seq_len,
                 self.embedding_dim,
@@ -161,8 +161,8 @@ class TransformerSentenceEncoder(nn.Module):
         # embed positions
         positions = (
-            self.position_embeddings(tokens)
-            if self.position_embeddings is not None else None
+            self.embed_positions(tokens)
+            if self.embed_positions is not None else None
         )
         # embed segments
@@ -172,7 +172,7 @@ class TransformerSentenceEncoder(nn.Module):
             else None
         )
-        x = self.token_embeddings(tokens)
+        x = self.embed_tokens(tokens)
         if positions is not None:
             x += positions
         if segments is not None:
@@ -51,7 +51,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         # Initialize blocks
         self.activation_fn = gelu if use_gelu else F.relu
-        self.self_attention = MultiheadAttention(
+        self.self_attn = MultiheadAttention(
             self.embedding_dim, num_attention_heads, dropout=attention_dropout
         )
@@ -97,7 +97,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         residual = x
         x = self._maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
-        x, attn = self.self_attention(
+        x, attn = self.self_attn(
             query=x,
             key=x,
             value=x,
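
The renamed `self_attn` is invoked with the same tensor as query, key, and value, which is the standard self-attention call pattern. A rough sketch of that call, using `torch.nn.MultiheadAttention` as a stand-in for fairseq's `MultiheadAttention` (whose constructor here likewise takes an embedding dim, a head count, and a dropout rate):

```python
import torch
import torch.nn as nn

# Sketch of the call pattern in the diff above; torch.nn.MultiheadAttention
# stands in for fairseq's MultiheadAttention.
embedding_dim, num_attention_heads = 16, 4
self_attn = nn.MultiheadAttention(embedding_dim, num_attention_heads, dropout=0.1)

x = torch.randn(10, 2, embedding_dim)  # (seq_len, batch, embed_dim)
# Self-attention: one tensor serves as query, key, and value.
x_out, attn = self_attn(query=x, key=x, value=x)
print(x_out.shape)  # torch.Size([10, 2, 16])
```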