Commit 184629a7 authored by Juan Miguel Pino, committed by Facebook GitHub Bot

Add onnx_trace argument for learned embeddings (#492)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/492

This argument was missing, so the Transformer model could not be exported to ONNX when using learned positional embeddings. See also https://github.com/pytorch/translate/pull/335

Reviewed By: jhcross

Differential Revision: D13984781

fbshipit-source-id: 2187377e952ff587e07237de312c5b68f7d68891
parent fbd4cef9
@@ -20,6 +20,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
         self.left_pad = left_pad
+        self.onnx_trace = False
 
     def forward(self, input, incremental_state=None):
         """Input is expected to be of size [bsz x seqlen]."""
@@ -27,7 +28,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
             # positions is the same for every token when decoding a single step
             positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
-            positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+            positions = utils.make_positions(input.data, self.padding_idx, self.left_pad, self.onnx_trace)
         return super().forward(positions)
 
     def max_positions(self):
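For context, a minimal usage sketch of how the new flag might be toggled before exporting to ONNX. This is not part of the commit; the import path and the example sizes below are assumptions for illustration.

    # Hedged sketch: assumes LearnedPositionalEmbedding is importable from
    # fairseq.modules and uses made-up example sizes.
    import torch
    from fairseq.modules import LearnedPositionalEmbedding

    # Same constructor signature as in the diff above.
    emb = LearnedPositionalEmbedding(
        num_embeddings=1026, embedding_dim=512, padding_idx=1, left_pad=False
    )

    # The flag defaults to False; setting it before tracing makes
    # utils.make_positions take the ONNX-compatible code path.
    emb.onnx_trace = True

    # Input is expected to be of size [bsz x seqlen]; 1 is the padding index here.
    tokens = torch.tensor([[5, 6, 7, 1]])
    pos_embeddings = emb(tokens)  # [bsz x seqlen x embedding_dim]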