Commit c6d4386c authored by Alexei Baevski, committed by Myle Ott

Fix embedding initialization for padding

parent 1ec5f0a0
...@@ -416,13 +416,15 @@ class FConvDecoder(FairseqIncrementalDecoder):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    m.weight.data.normal_(0, 0.1)
+    nn.init.normal(m.weight, 0, 0.1)
+    nn.init.constant(m.weight[padding_idx], 0)
     return m


 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
     m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-    m.weight.data.normal_(0, 0.1)
+    nn.init.normal(m.weight, 0, 0.1)
+    nn.init.constant(m.weight[padding_idx], 0)
     return m
...
...@@ -403,7 +403,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    m.weight.data.uniform_(-0.1, 0.1)
+    nn.init.uniform(m.weight, -0.1, 0.1)
+    nn.init.constant(m.weight[padding_idx], 0)
     return m
...
...@@ -379,6 +379,7 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, le
     if learned:
         m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
         nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
+        nn.init.constant(m.weight[padding_idx], 0)
     else:
         m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
     return m
...
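As context for the change above, a minimal sketch (not part of the commit) of the behaviour being fixed: nn.Embedding zeroes the padding_idx row when it is constructed, but re-initializing the whole weight matrix afterwards overwrites that row, so it has to be zeroed again explicitly. The sketch uses the in-place nn.init.normal_ / nn.init.constant_ spellings from current PyTorch; the commit itself uses the older nn.init.normal / nn.init.constant names. The sizes and padding index are illustrative only.

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=1)
print(torch.all(emb.weight[1] == 0).item())   # True: the padding row starts out as zeros

nn.init.normal_(emb.weight, 0, 0.1)           # re-initializing the whole matrix...
print(torch.all(emb.weight[1] == 0).item())   # False: ...clobbers the zero padding row

nn.init.constant_(emb.weight[1], 0)           # the fix: zero the padding row again
print(torch.all(emb.weight[1] == 0).item())   # True: padding tokens embed to zeros again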