OpenDAS / Fairseq · Commits

Commit 1b5a498c
Authored Apr 17, 2018 by Alexei Baevski; committed by Myle Ott on Jun 15, 2018
Parent: a3e4c4c3

allow overwriting args for different architectures
Changes: 3 changed files, with 65 additions and 75 deletions (+65 -75)

  fairseq/models/fconv.py        +23 -22
  fairseq/models/lstm.py         +17 -23
  fairseq/models/transformer.py  +25 -30
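The pattern is the same across all three files: each named architecture function previously called base_architecture(args) first and then assigned its own values unconditionally, so anything the user had already set on args was clobbered. After this commit the architecture functions fill each value in with getattr(args, name, default) and only call base_architecture(args) after their own overrides, so a value that is already present on args wins. A minimal sketch of the difference, using a bare argparse.Namespace in place of fairseq's parsed arguments (illustrative only, not fairseq code):

from argparse import Namespace

def old_style(args):
    # old behavior: unconditional assignment overwrites user-supplied values
    args.encoder_embed_dim = 256

def new_style(args):
    # new behavior: getattr keeps a value that is already set
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)

args = Namespace(encoder_embed_dim=512)   # as if the user had asked for 512
old_style(args)
print(args.encoder_embed_dim)             # 256 -- the user's 512 is lost

args = Namespace(encoder_embed_dim=512)
new_style(args)
print(args.encoder_embed_dim)             # 512 -- the user's value survives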
fairseq/models/fconv.py
...
@@ -51,6 +51,9 @@ class FConvModel(FairseqModel):
     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
         """Build a new model instance."""
+        # make sure that all args are properly defaulted (in case there are any new ones)
+        base_architecture(args)
+
         if not hasattr(args, 'max_source_positions'):
             args.max_source_positions = args.max_positions
...
@@ -468,47 +471,45 @@ def base_architecture(args):
 @register_model_architecture('fconv', 'fconv_iwslt_de_en')
 def fconv_iwslt_de_en(args):
-    base_architecture(args)
-    args.encoder_embed_dim = 256
-    args.encoder_layers = '[(256, 3)] * 4'
-    args.decoder_embed_dim = 256
-    args.decoder_layers = '[(256, 3)] * 3'
-    args.decoder_out_embed_dim = 256
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+    args.encoder_layers = getattr(args, 'encoder_layers', '[(256, 3)] * 4')
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+    args.decoder_layers = getattr(args, 'decoder_layers', '[(256, 3)] * 3')
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
+    base_architecture(args)
 @register_model_architecture('fconv', 'fconv_wmt_en_ro')
 def fconv_wmt_en_ro(args):
-    base_architecture(args)
-    args.encoder_embed_dim = 512
-    args.encoder_layers = '[(512, 3)] * 20'
-    args.decoder_embed_dim = 512
-    args.decoder_layers = '[(512, 3)] * 20'
-    args.decoder_out_embed_dim = 512
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+    base_architecture(args)
 @register_model_architecture('fconv', 'fconv_wmt_en_de')
 def fconv_wmt_en_de(args):
-    base_architecture(args)
     convs = '[(512, 3)] * 9'       # first 9 layers have 512 units
     convs += ' + [(1024, 3)] * 4'  # next 4 layers have 1024 units
     convs += ' + [(2048, 1)] * 2'  # final 2 layers use 1x1 convolutions
-    args.encoder_embed_dim = 768
-    args.encoder_layers = convs
-    args.decoder_embed_dim = 768
-    args.decoder_layers = convs
-    args.decoder_out_embed_dim = 512
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
+    args.encoder_layers = getattr(args, 'encoder_layers', convs)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
+    args.decoder_layers = getattr(args, 'decoder_layers', convs)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+    base_architecture(args)
 @register_model_architecture('fconv', 'fconv_wmt_en_fr')
 def fconv_wmt_en_fr(args):
-    base_architecture(args)
     convs = '[(512, 3)] * 6'       # first 6 layers have 512 units
     convs += ' + [(768, 3)] * 4'   # next 4 layers have 768 units
     convs += ' + [(1024, 3)] * 3'  # next 3 layers have 1024 units
     convs += ' + [(2048, 1)] * 1'  # next 1 layer uses 1x1 convolutions
     convs += ' + [(4096, 1)] * 1'  # final 1 layer uses 1x1 convolutions
-    args.encoder_embed_dim = 768
-    args.encoder_layers = convs
-    args.decoder_embed_dim = 768
-    args.decoder_layers = convs
-    args.decoder_out_embed_dim = 512
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
+    args.encoder_layers = getattr(args, 'encoder_layers', convs)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
+    args.decoder_layers = getattr(args, 'decoder_layers', convs)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+    base_architecture(args)
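The base_architecture(args) call added at the top of build_model (in fconv.py above, and likewise in lstm.py below) serves the purpose stated in its comment: because every architecture value is now filled in with getattr, running base_architecture over an args namespace that is missing some attributes, for example one restored from a checkpoint saved before a newer option existed, simply fills the gaps with defaults and leaves everything else untouched. A minimal sketch of that defaulting step, with a toy stand-in for base_architecture and an invented brand_new_option attribute (both purely illustrative):

from argparse import Namespace

def toy_base_architecture(args):
    # same getattr pattern as the real base_architecture: only fill in what is missing
    args.dropout = getattr(args, 'dropout', 0.1)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.brand_new_option = getattr(args, 'brand_new_option', False)  # option added after the checkpoint was saved

# args as restored from an old checkpoint: brand_new_option does not exist yet
old_checkpoint_args = Namespace(dropout=0.2, encoder_embed_dim=256)

# what build_model now does first: default any missing args
toy_base_architecture(old_checkpoint_args)

print(old_checkpoint_args.dropout)           # 0.2   (kept from the checkpoint)
print(old_checkpoint_args.brand_new_option)  # False (newly defaulted)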
fairseq/models/lstm.py
...
@@ -61,6 +61,9 @@ class LSTMModel(FairseqModel):
     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
         """Build a new model instance."""
+        # make sure that all args are properly defaulted (in case there are any new ones)
+        base_architecture(args)
+
         if not hasattr(args, 'encoder_embed_path'):
             args.encoder_embed_path = None
...
@@ -452,32 +455,23 @@ def base_architecture(args):
 @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
 def lstm_wiseman_iwslt_de_en(args):
-    base_architecture(args)
-    args.encoder_embed_dim = 256
-    args.encoder_hidden_size = 256
-    args.encoder_layers = 1
-    args.encoder_bidirectional = False
-    args.encoder_dropout_in = 0
-    args.encoder_dropout_out = 0
-    args.decoder_embed_dim = 256
-    args.decoder_hidden_size = 256
-    args.decoder_layers = 1
-    args.decoder_out_embed_dim = 256
-    args.decoder_attention = '1'
-    args.decoder_dropout_in = 0
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+    args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0)
+    args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
+    args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0)
+    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
+    base_architecture(args)
 @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
 def lstm_luong_wmt_en_de(args):
-    base_architecture(args)
-    args.encoder_embed_dim = 1000
-    args.encoder_hidden_size = 1000
-    args.encoder_layers = 4
-    args.encoder_dropout_out = 0
-    args.encoder_bidirectional = False
-    args.decoder_embed_dim = 1000
-    args.decoder_hidden_size = 1000
-    args.decoder_layers = 4
-    args.decoder_out_embed_dim = 1000
-    args.decoder_attention = '1'
-    args.decoder_dropout_out = 0
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
+    args.encoder_layers = getattr(args, 'encoder_layers', 4)
+    args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000)
+    args.decoder_layers = getattr(args, 'decoder_layers', 4)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000)
+    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0)
+    base_architecture(args)
fairseq/models/transformer.py
...
@@ -96,6 +96,7 @@ class TransformerModel(FairseqModel):
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
...
@@ -155,6 +156,7 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqIncrementalDecoder):
     """Transformer decoder."""
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
...
@@ -250,6 +252,7 @@ class TransformerEncoderLayer(nn.Module):
     We default to the approach in the paper, but the tensor2tensor approach can
     be enabled by setting `normalize_before=True`.
     """
     def __init__(self, args):
         super().__init__()
         self.embed_dim = args.encoder_embed_dim
...
@@ -292,6 +295,7 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""
     def __init__(self, args):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
...
@@ -399,56 +403,47 @@ def base_architecture(args):
 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):
-    base_architecture(args)
-    args.encoder_embed_dim = 256
-    args.encoder_ffn_embed_dim = 512
-    args.encoder_layers = 3
-    args.encoder_attention_heads = 4
-    args.decoder_embed_dim = 256
-    args.decoder_ffn_embed_dim = 512
-    args.decoder_layers = 3
-    args.decoder_attention_heads = 4
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
+    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
+    args.encoder_layers = getattr(args, 'encoder_layers', 3)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
+    args.decoder_layers = getattr(args, 'decoder_layers', 3)
+    base_architecture(args)
 @register_model_architecture('transformer', 'transformer_wmt_en_de')
 def transformer_wmt_en_de(args):
     base_architecture(args)
-    args.encoder_embed_dim = 512
-    args.encoder_ffn_embed_dim = 2048
-    args.encoder_layers = 6
-    args.encoder_attention_heads = 8
-    args.decoder_embed_dim = 512
-    args.decoder_ffn_embed_dim = 2048
-    args.decoder_layers = 6
-    args.decoder_attention_heads = 8
 # parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
 @register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
 def transformer_vaswani_wmt_en_de_big(args):
-    base_architecture(args)
-    args.encoder_embed_dim = 1024
-    args.encoder_ffn_embed_dim = 4096
-    args.encoder_layers = 6
-    args.encoder_attention_heads = 16
-    args.decoder_embed_dim = 1024
-    args.decoder_ffn_embed_dim = 4096
-    args.decoder_layers = 6
-    args.decoder_attention_heads = 16
-    args.dropout = 0.3
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
+    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
+    args.dropout = getattr(args, 'dropout', 0.3)
+    base_architecture(args)
 @register_model_architecture('transformer', 'transformer_wmt_en_de_big')
 def transformer_wmt_en_de_big(args):
-    transformer_vaswani_wmt_en_de_big(args)
-    args.attention_dropout = 0.1
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    transformer_vaswani_wmt_en_de_big(args)
 # default parameters used in tensor2tensor implementation
 @register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
 def transformer_wmt_en_de_big_t2t(args):
-    transformer_vaswani_wmt_en_de_big(args)
-    args.encoder_normalize_before = True
-    args.decoder_normalize_before = True
-    args.attention_dropout = 0.1
-    args.relu_dropout = 0.1
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
+    args.encoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
+    transformer_vaswani_wmt_en_de_big(args)