OpenDAS / Fairseq · Commit 885e7ec9

character token embeddings for word level predictions

Authored Jul 28, 2018 by Alexei Baevski; committed by Myle Ott on Sep 03, 2018.
Parent: 616afddd
Showing 5 changed files with 253 additions and 2 deletions (+253 −2).
fairseq/models/transformer.py (+24 −2)
fairseq/modules/__init__.py (+4 −0)
fairseq/modules/character_token_embedder.py (+126 −0)
fairseq/modules/highway.py (+55 −0)
tests/test_character_token_embedder.py (+44 −0)
fairseq/models/transformer.py

@@ -15,7 +15,8 @@ from fairseq import options
 from fairseq import utils

 from fairseq.modules import (
-    AdaptiveSoftmax, LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding
+    AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
+    SinusoidalPositionalEmbedding
 )

 from . import (

@@ -161,6 +162,15 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='if set, disables positional embeddings (outside self attention)')
         parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
                             help='share decoder input and output embeddings')
+        parser.add_argument('--character-embeddings', default=False, action='store_true',
+                            help='if set, uses character embedding convolutions to produce token embeddings')
+        parser.add_argument('--character-filters', type=str, metavar='LIST',
+                            default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
+                            help='size of character embeddings')
+        parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4,
+                            help='size of character embeddings')
+        parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
+                            help='number of highway layers for character token embedder')

     @classmethod
     def build_model(cls, args, task):

@@ -174,7 +184,19 @@ class TransformerLanguageModel(FairseqLanguageModel):
         if not hasattr(args, 'max_target_positions'):
             args.max_target_positions = args.tokens_per_sample

-        embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
+        if args.character_embeddings:
+            if not hasattr(args, 'char_embedder_highway_layers'):
+                args.char_embedder_highway_layers = 0
+            if not hasattr(args, 'character_filters'):
+                args.character_filters = '[(1, 4), (2, 8), (3, 16), (4, 32), (5, 64)]'
+            embed_tokens = CharacterTokenEmbedder(task.dictionary,
+                                                  eval(args.character_filters),
+                                                  args.character_embedding_dim,
+                                                  args.decoder_embed_dim,
+                                                  args.char_embedder_highway_layers,
+                                                  )
+        else:
+            embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())

         decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
         return TransformerLanguageModel(decoder)
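For context (an illustrative sketch, not part of the commit): with --character-embeddings set, the build_model branch above swaps the usual Embedding for a CharacterTokenEmbedder built from the parsed flag values. A minimal reproduction of that wiring, assuming a toy Dictionary and the argparse defaults shown above:

from fairseq.data import Dictionary
from fairseq.modules import CharacterTokenEmbedder

# Toy vocabulary purely for illustration.
dictionary = Dictionary()
for word in ['hello', 'world']:
    dictionary.add_symbol(word)

# --character-filters is stored as a string; build_model eval()s it into (width, out_channels) pairs.
character_filters = '[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]'

embed_tokens = CharacterTokenEmbedder(
    dictionary,
    eval(character_filters),
    4,    # --character-embedding-dim default
    512,  # args.decoder_embed_dim (assumed value for this sketch)
    2,    # --char-embedder-highway-layers default
)
print(embed_tokens.embedding_dim)  # 512: the word-level embedding size handed to the decoder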
fairseq/modules/__init__.py

@@ -7,9 +7,11 @@
 from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
+from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .grad_multiply import GradMultiply
+from .highway import Highway
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention

@@ -19,9 +21,11 @@ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
 __all__ = [
     'AdaptiveSoftmax',
     'BeamableMM',
+    'CharacterTokenEmbedder',
     'ConvTBC',
     'DownsampledMultiHeadAttention',
     'GradMultiply',
+    'Highway',
     'LearnedPositionalEmbedding',
     'LinearizedConvolution',
     'MultiheadAttention',
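A quick sanity check (not part of the commit) that the two new modules are exported from the package namespace after this change:

import fairseq.modules as modules

assert 'CharacterTokenEmbedder' in modules.__all__
assert 'Highway' in modules.__all__

from fairseq.modules import CharacterTokenEmbedder, Highway  # both imports resolve after this commit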
fairseq/modules/character_token_embedder.py (new file, 0 → 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import numpy as np
import torch
import torch.nn.functional as F

from torch import nn
from torch.nn.utils.rnn import pad_sequence
from typing import List, Tuple

from .highway import Highway
from fairseq.data import Dictionary


class CharacterTokenEmbedder(torch.nn.Module):
    def __init__(
            self,
            vocab: Dictionary,
            filters: List[Tuple[int, int]],
            char_embed_dim: int,
            word_embed_dim: int,
            highway_layers: int,
            max_char_len: int = 50,
    ):
        super(CharacterTokenEmbedder, self).__init__()

        self.embedding_dim = word_embed_dim
        self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
        self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
        self.eos_idx, self.unk_idx = 0, 1

        self.convolutions = nn.ModuleList()
        for width, out_c in filters:
            self.convolutions.append(
                nn.Conv1d(char_embed_dim, out_c, kernel_size=width)
            )

        final_dim = sum(f[1] for f in filters)

        self.highway = Highway(final_dim, highway_layers)
        self.projection = nn.Linear(final_dim, word_embed_dim)

        self.set_vocab(vocab, max_char_len)
        self.reset_parameters()

    def set_vocab(self, vocab, max_char_len):
        word_to_char = torch.LongTensor(len(vocab), max_char_len)

        truncated = 0
        for i in range(len(vocab)):
            if i < vocab.nspecial:
                char_idxs = [0] * max_char_len
            else:
                chars = vocab[i].encode()
                # +1 for padding
                char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars))
            if len(char_idxs) > max_char_len:
                truncated += 1
                char_idxs = char_idxs[:max_char_len]
            word_to_char[i] = torch.LongTensor(char_idxs)

        if truncated > 0:
            print('Truncated {} words longer than {} characters'.format(truncated, max_char_len))

        self.vocab = vocab
        self.word_to_char = word_to_char

    @property
    def padding_idx(self):
        return self.vocab.pad()

    def reset_parameters(self):
        nn.init.xavier_normal_(self.char_embeddings.weight)
        nn.init.xavier_normal_(self.symbol_embeddings)
        nn.init.xavier_normal_(self.projection.weight)

        nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.)
        nn.init.constant_(self.projection.bias, 0.)

    def forward(
            self,
            words: torch.Tensor,
    ):
        self.word_to_char = self.word_to_char.type_as(words)

        flat_words = words.view(-1)
        word_embs = self._convolve(self.word_to_char[flat_words])

        pads = flat_words.eq(self.vocab.pad())
        if pads.any():
            word_embs[pads] = 0

        eos = flat_words.eq(self.vocab.eos())
        if eos.any():
            word_embs[eos] = self.symbol_embeddings[self.eos_idx]

        unk = flat_words.eq(self.vocab.unk())
        if unk.any():
            word_embs[unk] = self.symbol_embeddings[self.unk_idx]

        return word_embs.view(words.size() + (-1,))

    def _convolve(
            self,
            char_idxs: torch.Tensor,
    ):
        char_embs = self.char_embeddings(char_idxs)
        char_embs = char_embs.transpose(1, 2)  # BTC -> BCT

        conv_result = []

        for i, conv in enumerate(self.convolutions):
            x = conv(char_embs)
            x, _ = torch.max(x, -1)
            x = F.relu(x)
            conv_result.append(x)

        conv_result = torch.cat(conv_result, dim=-1)
        conv_result = self.highway(conv_result)

        return self.projection(conv_result)
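_convolve above is the character-CNN core: each Conv1d slides over a word's character sequence, the result is max-pooled over the character axis and passed through ReLU, the per-filter features are concatenated, and the highway stack plus a linear projection map them to word_embed_dim. A standalone shape-flow sketch of that step with toy sizes (assumed, not from the commit):

import torch
import torch.nn.functional as F
from torch import nn

filters = [(1, 64), (2, 128)]                       # (kernel width, out_channels) pairs
char_embs = torch.randn(3, 50, 4).transpose(1, 2)   # 3 words x 50 chars x char_embed_dim=4, BTC -> BCT
convs = [nn.Conv1d(4, out_c, kernel_size=width) for width, out_c in filters]

pooled = []
for conv in convs:
    x = conv(char_embs)        # B x out_channels x T'
    x, _ = torch.max(x, -1)    # max-pool over the character positions
    pooled.append(F.relu(x))

features = torch.cat(pooled, dim=-1)
print(features.shape)          # torch.Size([3, 192]); Highway + Linear then map 192 -> word_embed_dim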
fairseq/modules/highway.py (new file, 0 → 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn.functional as F

from torch import nn


class Highway(torch.nn.Module):
    """
    A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
    Adopted from the AllenNLP implementation.
    """

    def __init__(
            self,
            input_dim: int,
            num_layers: int = 1
    ):
        super(Highway, self).__init__()
        self.input_dim = input_dim
        self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
                                     for _ in range(num_layers)])
        self.activation = nn.ReLU()

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.layers:
            # As per comment in AllenNLP:
            # We should bias the highway layer to just carry its input forward. We do that by
            # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
            # be high, so we will carry the input forward. The bias on `B(x)` is the second half
            # of the bias vector in each Linear layer.
            nn.init.constant_(layer.bias[self.input_dim:], 1)
            nn.init.constant_(layer.bias[:self.input_dim], 0)
            nn.init.xavier_normal_(layer.weight)

    def forward(
            self,
            x: torch.Tensor
    ):
        for layer in self.layers:
            projection = layer(x)
            proj_x, gate = projection.chunk(2, dim=-1)
            proj_x = self.activation(proj_x)
            gate = F.sigmoid(gate)
            x = gate * x + (1 - gate) * proj_x
        return x
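A small usage sketch (not part of the commit) of the gating behaviour above: each layer computes g = sigmoid(W_T x + b_T) and returns g * x + (1 - g) * ReLU(W_H x + b_H), and because the gate bias is initialised to 1, a freshly initialised layer mostly carries its input through.

import torch
from fairseq.modules import Highway

highway = Highway(input_dim=8, num_layers=2)
x = torch.randn(5, 8)
y = highway(x)
print(y.shape)                   # torch.Size([5, 8]): highway layers preserve the input dimension
print(torch.dist(x, y).item())   # relatively small at init, since sigmoid(1) ~ 0.73 favours the carry path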
tests/test_character_token_embedder.py (new file, 0 → 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import unittest

from fairseq.data import Dictionary
from fairseq.modules import CharacterTokenEmbedder


class TestCharacterTokenEmbedder(unittest.TestCase):
    def test_character_token_embedder(self):
        vocab = Dictionary()
        vocab.add_symbol('hello')
        vocab.add_symbol('there')

        embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)

        test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
        max_len = max(len(s) for s in test_sents)
        input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
        for i in range(len(test_sents)):
            input[i][0] = vocab.eos()
            for j in range(len(test_sents[i])):
                input[i][j + 1] = vocab.index(test_sents[i][j])
            input[i][j + 2] = vocab.eos()
        embs = embedder(input)

        assert embs.size() == (len(test_sents), max_len + 2, 5)
        assert embs[0][0].equal(embs[1][0])
        assert embs[0][0].equal(embs[0][-1])
        assert embs[0][1].equal(embs[2][1])
        assert embs[0][3].equal(embs[1][1])

        embs.sum().backward()
        assert embedder.char_embeddings.weight.grad is not None


if __name__ == '__main__':
    unittest.main()