OpenDAS / Fairseq, commit c21a6e29
Authored Jan 06, 2018 by Myle Ott
Parent: 185a0df5

Move positional embeddings into LearnedPositionalEmbedding module
Showing 3 changed files with 82 additions and 46 deletions:

fairseq/models/fconv.py                           (+22 / -46)
fairseq/modules/__init__.py                       (+2 / -0)
fairseq/modules/learned_positional_embedding.py   (+58 / -0, new file)
fairseq/models/fconv.py
@@ -13,26 +13,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 from fairseq.data import LanguagePairDataset
-from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
+from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
 from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel


-def make_positions(tokens, padding_idx, left_pad, offset=0):
-    seqlen = tokens.size(1)
-    if not hasattr(make_positions, 'range'):
-        make_positions.range = tokens.new()
-    if make_positions.range.numel() < offset + seqlen:
-        # offset positions by the padding index
-        torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
-                     out=make_positions.range)
-    mask = tokens.ne(padding_idx)
-    positions = make_positions.range[offset:offset + seqlen].expand_as(tokens)
-    if left_pad:
-        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-    return tokens.clone().masked_scatter_(mask, positions[mask])
-
-
 class FConvModel(FairseqModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
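
The helper removed above now lives in LearnedPositionalEmbedding.make_positions (new module further down). For readers following the change, here is a minimal, self-contained sketch of the numbering it computes, written against present-day PyTorch and not part of the commit; it mirrors the removed function with offset left at its default of 0 and without the cached range buffer. Non-padding tokens are replaced by their position, counting from padding_idx + 1, and with left_pad=True each row is shifted so its first real token still gets position padding_idx + 1.

import torch

def make_positions_sketch(tokens, padding_idx, left_pad):
    # positions count from padding_idx + 1 so that padding_idx itself is never
    # used as a real position
    seqlen = tokens.size(1)
    positions = torch.arange(padding_idx + 1, padding_idx + 1 + seqlen).expand_as(tokens)
    mask = tokens.ne(padding_idx)
    if left_pad:
        # shift each row so the first non-padding token gets padding_idx + 1
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    # keep padding_idx at padded slots, write position numbers everywhere else
    return tokens.clone().masked_scatter_(mask, positions[mask])

tokens = torch.tensor([[1, 1, 7, 9],   # left-padded, padding_idx = 1
                       [4, 5, 6, 8]])
print(make_positions_sketch(tokens, padding_idx=1, left_pad=True))
# tensor([[1, 1, 2, 3],
#         [2, 3, 4, 5]])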
@@ -51,7 +36,8 @@ class FConvEncoder(FairseqEncoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)

         in_channels = convolutions[0][0]
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -68,11 +54,8 @@ class FConvEncoder(FairseqEncoder):
         self.fc2 = Linear(in_channels, embed_dim)

     def forward(self, src_tokens):
-        positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
-
         # embed tokens and positions
-        x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
         x = F.dropout(x, p=self.dropout, training=self.training)
         input_embedding = x
@@ -106,7 +89,7 @@ class FConvEncoder(FairseqEncoder):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()


 class AttentionLayer(nn.Module):
@@ -170,7 +153,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_TARGET)

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
@@ -190,32 +174,18 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

     def forward(self, input_tokens, encoder_out):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input_tokens, encoder_out)
-        else:
-            return self.batch_forward(input_tokens, encoder_out)
-
-    def batch_forward(self, input_tokens, encoder_out):
-        """Forward pass for decoding multiple time steps in batch mode."""
-        positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
-        return self._forward(input_tokens, positions, encoder_out)
-
-    def incremental_forward(self, input_tokens, encoder_out):
-        """Forward pass for one time step."""
-        # positions is the same for every token when decoding a single step
-        positions = Variable(input_tokens.data.new(1, 1).fill_(
-            self.dictionary.pad() + input_tokens.size(1)))
-        # keep only the last token for incremental forward pass
-        return self._forward(input_tokens[:, -1:], positions, encoder_out)
-
-    def _forward(self, input_tokens, positions, encoder_out):
         # split and transpose encoder outputs
         encoder_a, encoder_b = self._split_encoder_out(encoder_out)

+        # embed positions
+        positions = self.embed_positions(input_tokens)
+
+        if self._is_incremental_eval:
+            # keep only the last token for incremental forward pass
+            input_tokens = input_tokens[:, -1:]
+
         # embed tokens and positions
-        x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(input_tokens) + positions
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x
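
One detail in the rewritten forward: positions are embedded from the full input_tokens before the incremental branch truncates to the last token, because in incremental mode LearnedPositionalEmbedding derives the single position from the full prefix length. A quick hedged check, with illustrative values that are not from the commit, that the batch and incremental paths agree on that last position:

import torch

padding_idx, t = 1, 3                                                  # prefix of length 3
batch_positions = torch.arange(padding_idx + 1, padding_idx + 1 + t)   # tensor([2, 3, 4])
incremental_position = torch.full((1, 1), padding_idx + t)             # tensor([[4]])
assert batch_positions[-1].item() == incremental_position.item()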
@@ -268,7 +238,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()

     def upgrade_state_dict(self, state_dict):
         if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
@@ -308,6 +278,12 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m


+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+    m.weight.data.normal_(0, 0.1)
+    return m
+
+
 def Linear(in_features, out_features, dropout=0):
     """Weight-normalized Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features)
fairseq/modules/__init__.py
@@ -9,11 +9,13 @@
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
 from .grad_multiply import GradMultiply
+from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution

 __all__ = [
     'BeamableMM',
     'ConvTBC',
     'GradMultiply',
+    'LearnedPositionalEmbedding',
     'LinearizedConvolution',
 ]
fairseq/modules/learned_positional_embedding.py (new file, mode 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad
        self.register_buffer('range_buf', None)
        self._is_incremental_eval = False

    def incremental_eval(self, mode=True):
        self._is_incremental_eval = mode

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        if self._is_incremental_eval:
            # positions is the same for every token when decoding a single step
            positions = Variable(input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
        else:
            positions = Variable(self.make_positions(input.data))
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1

    def make_positions(self, input):
        """Replace non-padding symbols with their position numbers."""
        if self.range_buf is None:
            self.range_buf = input.new()
        seqlen = input.size(1)
        if self.range_buf.numel() < seqlen:
            # offset positions by the padding index
            torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen, out=self.range_buf)
        mask = input.ne(self.padding_idx)
        positions = self.range_buf[:seqlen].expand_as(input)
        if self.left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return input.clone().masked_scatter_(mask, positions[mask])
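
For reference, a minimal usage sketch of the new module; it is not part of the commit, the sizes and tokens below are made up, and it assumes the 2018-era PyTorch Variable API that this file targets (on modern PyTorch the Variable wrapper is a no-op).

import torch
from torch.autograd import Variable
from fairseq.modules import LearnedPositionalEmbedding

# 1024 positions, 256-dim embeddings, padding index 1, right-padded batches
embed_positions = LearnedPositionalEmbedding(
    num_embeddings=1024, embedding_dim=256, padding_idx=1, left_pad=False)

tokens = Variable(torch.LongTensor([[5, 6, 7, 1]]))  # one right-padded sequence
pos = embed_positions(tokens)                        # -> [1 x 4 x 256]; padded slots
                                                     #    keep position padding_idx

print(embed_positions.max_positions())               # 1024 - 1 - 1 = 1022

# During single-step decoding, switch to incremental mode so forward() returns
# only the embedding of the current (last) position:
embed_positions.incremental_eval(True)

In fconv.py the module is constructed through the new PositionalEmbedding wrapper, which additionally initializes the weights with normal_(0, 0.1).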