Commit 7bbe528d authored by Haoran Li, committed by Facebook Github Bot

fixes on bi-transformer onnx

Summary: Replace the dynamic index_put with copying the values into a newly created tensor (via torch.where) for ONNX export

Reviewed By: wanchaol

Differential Revision: D13244573

fbshipit-source-id: 909f7913ad579ed035f29bb52321ff01e09a2c60
parent 866d0d2e
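For context, a minimal sketch (not the fairseq code itself) of the swap the summary describes: an in-place assignment through a boolean mask is traced by the ONNX exporter as an index_put with data-dependent indices, while torch.where copies the same values into a newly created tensor. The tensor names and shapes below are illustrative.

import torch

# Illustrative shapes mirroring the (num_words, embed_dim) layout in the diff.
word_embs = torch.randn(4, 8)
pads = torch.tensor([True, False, False, True])

# Eager-mode pattern: mutate word_embs through dynamic boolean indexing.
masked = word_embs.clone()
masked[pads] = 0

# Trace-friendly pattern from this commit: broadcast the mask over the
# embedding dimension and copy into a fresh tensor instead of indexing in place.
copied = torch.where(pads.unsqueeze(1), word_embs.new_zeros(1), word_embs)

assert torch.equal(masked, copied)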
@@ -32,6 +32,7 @@ class CharacterTokenEmbedder(torch.nn.Module):
    ):
        super(CharacterTokenEmbedder, self).__init__()
        self.onnx_trace = False
        self.embedding_dim = word_embed_dim
        self.max_char_len = max_char_len
        self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
@@ -58,6 +59,9 @@ class CharacterTokenEmbedder(torch.nn.Module):
        self.reset_parameters()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def set_vocab(self, vocab, max_char_len):
        word_to_char = torch.LongTensor(len(vocab), max_char_len)
@@ -101,7 +105,11 @@ class CharacterTokenEmbedder(torch.nn.Module):
            pads = chars[:, 0].eq(CHAR_PAD_IDX)
            eos = chars[:, 0].eq(CHAR_EOS_IDX)
            if eos.any():
                if self.onnx_trace:
                    chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars)
                else:
                    chars[eos] = 0
            unk = None
        else:
            flat_words = input.view(-1)
@@ -111,12 +119,18 @@ class CharacterTokenEmbedder(torch.nn.Module):
            unk = flat_words.eq(self.vocab.unk())

        word_embs = self._convolve(chars)
        if self.onnx_trace:
            if pads.any():
                word_embs = torch.where(pads.unsqueeze(1), word_embs.new_zeros(1), word_embs)
            if eos.any():
                word_embs = torch.where(eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs)
            if unk is not None and unk.any():
                word_embs = torch.where(unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs)
        else:
            if pads.any():
                word_embs[pads] = 0
            if eos.any():
                word_embs[eos] = self.symbol_embeddings[self.eos_idx]
            if unk is not None and unk.any():
                word_embs[unk] = self.symbol_embeddings[self.unk_idx]
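The eos and unk branches in the last hunk follow the same pattern, but swap whole rows for a learned symbol embedding instead of zeros. A hedged sketch of that equivalence, with symbol_emb standing in for self.symbol_embeddings[self.eos_idx]:

import torch

word_embs = torch.randn(4, 8)
eos = torch.tensor([False, True, False, False])
symbol_emb = torch.randn(8)  # stand-in for self.symbol_embeddings[self.eos_idx]

# Old path: dynamic index_put; rows selected by the eos mask are overwritten in place.
replaced = word_embs.clone()
replaced[eos] = symbol_emb

# Path taken when self.onnx_trace is set: the symbol embedding is broadcast
# into a new tensor wherever the mask is true; other rows are copied through.
traced = torch.where(eos.unsqueeze(1), symbol_emb, word_embs)

assert torch.equal(replaced, traced)

Callers are expected to switch the embedder onto the torch.where path by calling prepare_for_onnx_export_() before tracing; in eager training and inference the in-place branch is unchanged.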