OpenDAS / Fairseq · Commits

Commit f472d141

Support tied embeddings in LSTM encoder/decoder

Authored Jul 03, 2018 by Stephen Roller; committed by Myle Ott on Jul 25, 2018.
Parent: a7d0bd0e

Showing 1 changed file with 50 additions and 6 deletions: fairseq/models/lstm.py (+50 −6)
fairseq/models/lstm.py (view file @ f472d141)
@@ -59,6 +59,12 @@ class LSTMModel(FairseqModel):
                             help='dropout probability for decoder input embedding')
         parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
                             help='dropout probability for decoder output')
+        parser.add_argument('--share-decoder-input-output-embed', default=False,
+                            action='store_true',
+                            help='share decoder input and output embeddings')
+        parser.add_argument('--share-all-embeddings', default=False, action='store_true',
+                            help='share encoder, decoder and output embeddings'
+                                 ' (requires shared dictionary and embed dim)')
 
     @classmethod
     def build_model(cls, args, task):
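For context, a hypothetical training invocation exercising the two new flags (the data-bin path and hyperparameters are placeholders, assuming a corpus preprocessed with a joint source/target dictionary):

    python train.py data-bin/iwslt14.joint-bpe.de-en \
        --arch lstm --share-all-embeddings \
        --encoder-embed-dim 512 --decoder-embed-dim 512 --decoder-out-embed-dim 512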
@@ -78,10 +84,39 @@ class LSTMModel(FairseqModel):
         if args.encoder_embed_path:
             pretrained_encoder_embed = load_pretrained_embedding_from_file(
                 args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
-        pretrained_decoder_embed = None
-        if args.decoder_embed_path:
-            pretrained_decoder_embed = load_pretrained_embedding_from_file(
-                args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)
+
+        if args.share_all_embeddings:
+            # double-check that all parameter combinations are valid
+            if task.source_dictionary != task.target_dictionary:
+                raise RuntimeError(
+                    '--share-all-embeddings requires a joint dictionary')
+            if args.decoder_embed_path and (
+                    args.decoder_embed_path != args.encoder_embed_path):
+                raise RuntimeError(
+                    '--share-all-embeddings is not compatible with --decoder-embed-path')
+            if args.encoder_embed_dim != args.decoder_embed_dim:
+                raise RuntimeError(
+                    '--share-all-embeddings requires --encoder-embed-dim to '
+                    'match --decoder-embed-dim')
+            pretrained_decoder_embed = pretrained_encoder_embed
+            args.share_decoder_input_output_embed = True
+        else:
+            # separate decoder input embeddings
+            pretrained_decoder_embed = None
+            if args.decoder_embed_path:
+                pretrained_decoder_embed = load_pretrained_embedding_from_file(
+                    args.decoder_embed_path,
+                    task.target_dictionary,
+                    args.decoder_embed_dim
+                )
+
+        # one last double-check of parameter combinations
+        if args.share_decoder_input_output_embed and (
+                args.decoder_embed_dim != args.decoder_out_embed_dim):
+            raise RuntimeError(
+                '--share-decoder-input-output-embed requires '
+                '--decoder-embed-dim to match --decoder-out-embed-dim')
 
         encoder = LSTMEncoder(
             dictionary=task.source_dictionary,
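The dimension checks above exist because a tied output projection reuses the embedding matrix as-is. A minimal standalone sketch of the failure mode they pre-empt (hypothetical sizes, plain PyTorch, not fairseq code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    V, D = 10000, 512                 # joint vocab size, shared embed dim
    embed = nn.Embedding(V, D)
    h = torch.randn(2, 7, D)          # decoder states at out_embed_dim == D
    logits = F.linear(h, embed.weight)   # OK: shape (2, 7, V)
    h_bad = torch.randn(2, 7, 256)    # out_embed_dim != embed_dim
    # F.linear(h_bad, embed.weight)   # would raise a shape-mismatch error,
    #                                 # which the RuntimeError above catches early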
@@ -105,6 +140,7 @@ class LSTMModel(FairseqModel):
             encoder_embed_dim=args.encoder_embed_dim,
             encoder_output_units=encoder.output_units,
             pretrained_embed=pretrained_decoder_embed,
+            share_input_output_embed=args.share_decoder_input_output_embed,
         )
         return cls(encoder, decoder)
@@ -251,11 +287,13 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
         encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
+        share_input_output_embed=False,
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
+        self.share_input_output_embed = share_input_output_embed
 
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
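A usage sketch for the new constructor argument (the toy Dictionary is a stand-in; in practice the task's target dictionary is passed in, and the small sizes are illustrative only):

    from fairseq.data.dictionary import Dictionary
    from fairseq.models.lstm import LSTMDecoder

    d = Dictionary()                      # toy target dictionary
    for w in ['hello', 'world']:
        d.add_symbol(w)
    decoder = LSTMDecoder(d, embed_dim=32, hidden_size=32, out_embed_dim=32,
                          share_input_output_embed=True)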
@@ -279,7 +317,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
         if hidden_size != out_embed_dim:
             self.additional_fc = Linear(hidden_size, out_embed_dim)
-        self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
+        if not self.share_input_output_embed:
+            self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
     def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
         encoder_out = encoder_out_dict['encoder_out']
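One consequence of the guard above: with tying enabled, no separate fc_out matrix is ever allocated. Back-of-envelope accounting (hypothetical sizes):

    V, D = 40000, 512        # vocab size, tied embed/output dim
    untied = V * D + V * D   # embed_tokens weight + fc_out weight
    tied = V * D             # one shared matrix
    print(untied - tied)     # 20,480,000 fewer parameters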
@@ -358,7 +397,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if hasattr(self, 'additional_fc'):
             x = self.additional_fc(x)
             x = F.dropout(x, p=self.dropout_out, training=self.training)
-        x = self.fc_out(x)
+        if self.share_input_output_embed:
+            x = F.linear(x, self.embed_tokens.weight)
+        else:
+            x = self.fc_out(x)
 
         return x, attn_scores
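The forward path above is the whole trick: the decoder's input embedding matrix doubles as the output projection, so one Parameter receives gradient from both uses. A minimal standalone sketch (assumed shapes, plain PyTorch):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    V, D = 100, 16
    embed = nn.Embedding(V, D)
    tokens = torch.tensor([[1, 2, 3]])
    x = embed(tokens)                      # input side reads embed.weight
    logits = F.linear(x, embed.weight)     # output side reuses the same Parameter
    logits.sum().backward()
    # embed.weight.grad now accumulates signal from both the input-embedding
    # lookup and the softmax projection: a single tied matrix.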
@@ -431,6 +473,8 @@ def base_architecture(args):
     args.decoder_attention = getattr(args, 'decoder_attention', '1')
     args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
     args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
+    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
+    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
 
 
 @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
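The getattr pattern above fills in architecture defaults without clobbering values the command line (or a named architecture) already set. A tiny illustration with a hypothetical args object:

    from argparse import Namespace

    args = Namespace(share_all_embeddings=True)   # set on the command line
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.share_decoder_input_output_embed = getattr(
        args, 'share_decoder_input_output_embed', False)
    print(args.share_all_embeddings)              # True: existing value kept
    print(args.share_decoder_input_output_embed)  # False: default filled in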