Commit 4afa455e authored by Haoran Li, committed by Facebook Github Bot

Make fairseq models compatible with character inputs and use character inputs for ELMo in PyText

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/321

Reviewed By: alexeib

Differential Revision: D10430186

fbshipit-source-id: 9cc8fe0f202cc49370cecf36312bcc9bf0b4deee
parent 613ffeea
@@ -17,6 +17,9 @@ from typing import List, Tuple
 from .highway import Highway
 from fairseq.data import Dictionary
 
+CHAR_PAD_IDX = 0
+CHAR_EOS_IDX = 257
+
 
 class CharacterTokenEmbedder(torch.nn.Module):
     def __init__(
@@ -27,13 +30,16 @@ class CharacterTokenEmbedder(torch.nn.Module):
             word_embed_dim: int,
             highway_layers: int,
             max_char_len: int = 50,
+            char_inputs: bool = False
     ):
         super(CharacterTokenEmbedder, self).__init__()
 
         self.embedding_dim = word_embed_dim
+        self.max_char_len = max_char_len
         self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
         self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
         self.eos_idx, self.unk_idx = 0, 1
+        self.char_inputs = char_inputs
 
         self.convolutions = nn.ModuleList()
         for width, out_c in filters:
@@ -84,26 +90,34 @@ class CharacterTokenEmbedder(torch.nn.Module):
     def forward(
             self,
-            words: torch.Tensor,
+            input: torch.Tensor,
     ):
-        self.word_to_char = self.word_to_char.type_as(words)
-
-        flat_words = words.view(-1)
-        word_embs = self._convolve(self.word_to_char[flat_words])
-
-        pads = flat_words.eq(self.vocab.pad())
+        if self.char_inputs:
+            chars = input.view(-1, self.max_char_len)
+            pads = chars[:, 0].eq(CHAR_PAD_IDX)
+            eos = chars[:, 0].eq(CHAR_EOS_IDX)
+            if eos.any():
+                chars[eos] = 0
+            unk = None
+        else:
+            self.word_to_char = self.word_to_char.type_as(input)
+            flat_words = input.view(-1)
+            chars = self.word_to_char[flat_words]
+            pads = flat_words.eq(self.vocab.pad())
+            eos = flat_words.eq(self.vocab.eos())
+            unk = flat_words.eq(self.vocab.unk())
+
+        word_embs = self._convolve(chars)
+
         if pads.any():
             word_embs[pads] = 0
 
-        eos = flat_words.eq(self.vocab.eos())
         if eos.any():
             word_embs[eos] = self.symbol_embeddings[self.eos_idx]
 
-        unk = flat_words.eq(self.vocab.unk())
-        if unk.any():
+        if unk is not None and unk.any():
             word_embs[unk] = self.symbol_embeddings[self.unk_idx]
 
-        return word_embs.view(words.size() + (-1,))
+        return word_embs.view(input.size()[:2] + (-1,))
 
     def _convolve(
             self,
...
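For context, a minimal usage sketch of the new char_inputs path follows (not part of this commit). The constructor arguments ahead of word_embed_dim (vocab, filters, char_embed_dim), the module path, and the exact character id convention are assumptions taken from the rest of the file, not from this hunk; what is grounded in the diff is the shape convention: in character mode the input is a (batch, seq_len, max_char_len) tensor of character ids, a word whose first id is CHAR_PAD_IDX (0) is treated as padding, and a word whose first id is CHAR_EOS_IDX (257) is treated as end-of-sentence.

# Minimal sketch, assuming the constructor also takes vocab, filters and
# char_embed_dim ahead of the arguments shown in the hunk above.
import torch

from fairseq.data import Dictionary
from fairseq.modules.character_token_embedder import CharacterTokenEmbedder

vocab = Dictionary()  # placeholder vocab; only the word-index path actually uses it
embedder = CharacterTokenEmbedder(
    vocab,
    filters=[(1, 64), (2, 128), (3, 192)],  # (kernel width, out channels) pairs
    char_embed_dim=16,
    word_embed_dim=128,
    highway_layers=1,
    max_char_len=50,
    char_inputs=True,  # consume character ids directly instead of word ids
)

batch, seq_len, max_char_len = 2, 5, 50
chars = torch.zeros(batch, seq_len, max_char_len, dtype=torch.long)
chars[:, :3, :4] = 7     # arbitrary ids in 1..256 standing in for real characters
chars[:, 3, 0] = 257     # a word starting with CHAR_EOS_IDX marks end-of-sentence
# positions left entirely at CHAR_PAD_IDX (0) are treated as padding
word_embs = embedder(chars)  # -> (batch, seq_len, word_embed_dim)

With char_inputs=False (the default), the same module still accepts a (batch, seq_len) tensor of word indices and looks the characters up through word_to_char, as in the else branch of forward() above.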