Commit 1b5a498c authored by Alexei Baevski, committed by Myle Ott

allow overwriting args for different architectures

parent a3e4c4c3
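
The change below replaces hard assignments in the named-architecture functions with `getattr` defaults and moves the `base_architecture(args)` call to the end, so values supplied on the command line are no longer clobbered by the architecture presets. A minimal sketch of the difference in behaviour (the argument name and numbers are illustrative only):

```python
from argparse import Namespace

# value supplied by the user, e.g. --decoder-out-embed-dim 128
args = Namespace(decoder_out_embed_dim=128)

# old style: the architecture function overwrites the user's choice
args.decoder_out_embed_dim = 256
assert args.decoder_out_embed_dim == 256  # user value lost

# new style: the preset only acts as a default
args = Namespace(decoder_out_embed_dim=128)
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
assert args.decoder_out_embed_dim == 128  # user value kept
```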
@@ -51,6 +51,9 @@ class FConvModel(FairseqModel):
     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
         """Build a new model instance."""
+        # make sure that all args are properly defaulted (in case there are any new ones)
+        base_architecture(args)
+
         if not hasattr(args, 'max_source_positions'):
             args.max_source_positions = args.max_positions
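
Calling `base_architecture(args)` at the start of `build_model` back-fills defaults for any arguments introduced after a model configuration was saved, which is what the inline comment refers to. A rough sketch of that effect, using a made-up `new_option` argument rather than anything from the commit:

```python
from argparse import Namespace

def base_architecture(args):
    # hypothetical argument added after the config below was created
    args.new_option = getattr(args, 'new_option', 0.1)

# args restored from an older checkpoint that predates `new_option`
old_args = Namespace(encoder_embed_dim=512)
base_architecture(old_args)   # what build_model now does up front
print(old_args.new_option)    # 0.1 -- no AttributeError later on
```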
@@ -468,47 +471,45 @@ def base_architecture(args):
 @register_model_architecture('fconv', 'fconv_iwslt_de_en')
 def fconv_iwslt_de_en(args):
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+    args.encoder_layers = getattr(args, 'encoder_layers', '[(256, 3)] * 4')
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+    args.decoder_layers = getattr(args, 'decoder_layers', '[(256, 3)] * 3')
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
     base_architecture(args)
-    args.encoder_embed_dim = 256
-    args.encoder_layers = '[(256, 3)] * 4'
-    args.decoder_embed_dim = 256
-    args.decoder_layers = '[(256, 3)] * 3'
-    args.decoder_out_embed_dim = 256

 @register_model_architecture('fconv', 'fconv_wmt_en_ro')
 def fconv_wmt_en_ro(args):
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
     base_architecture(args)
-    args.encoder_embed_dim = 512
-    args.encoder_layers = '[(512, 3)] * 20'
-    args.decoder_embed_dim = 512
-    args.decoder_layers = '[(512, 3)] * 20'
-    args.decoder_out_embed_dim = 512

 @register_model_architecture('fconv', 'fconv_wmt_en_de')
 def fconv_wmt_en_de(args):
-    base_architecture(args)
     convs = '[(512, 3)] * 9'       # first 9 layers have 512 units
     convs += ' + [(1024, 3)] * 4'  # next 4 layers have 1024 units
     convs += ' + [(2048, 1)] * 2'  # final 2 layers use 1x1 convolutions
-    args.encoder_embed_dim = 768
-    args.encoder_layers = convs
-    args.decoder_embed_dim = 768
-    args.decoder_layers = convs
-    args.decoder_out_embed_dim = 512
+
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
+    args.encoder_layers = getattr(args, 'encoder_layers', convs)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
+    args.decoder_layers = getattr(args, 'decoder_layers', convs)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+    base_architecture(args)

 @register_model_architecture('fconv', 'fconv_wmt_en_fr')
 def fconv_wmt_en_fr(args):
-    base_architecture(args)
     convs = '[(512, 3)] * 6'       # first 6 layers have 512 units
     convs += ' + [(768, 3)] * 4'   # next 4 layers have 768 units
     convs += ' + [(1024, 3)] * 3'  # next 3 layers have 1024 units
     convs += ' + [(2048, 1)] * 1'  # next 1 layer uses 1x1 convolutions
     convs += ' + [(4096, 1)] * 1'  # final 1 layer uses 1x1 convolutions
-    args.encoder_embed_dim = 768
-    args.encoder_layers = convs
-    args.decoder_embed_dim = 768
-    args.decoder_layers = convs
-    args.decoder_out_embed_dim = 512
+
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
+    args.encoder_layers = getattr(args, 'encoder_layers', convs)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
+    args.decoder_layers = getattr(args, 'decoder_layers', convs)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+    base_architecture(args)
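
With this pattern a derived architecture only has to pin the values it actually changes and can delegate everything else. For example, a hypothetical wider variant of `fconv_iwslt_de_en` could be registered as follows (the variant name is invented, and the import paths assume the usual fairseq module layout):

```python
from fairseq.models import register_model_architecture
from fairseq.models.fconv import fconv_iwslt_de_en

@register_model_architecture('fconv', 'fconv_iwslt_de_en_wide')  # hypothetical name
def fconv_iwslt_de_en_wide(args):
    # pin only the value that differs; everything else falls through to
    # fconv_iwslt_de_en and, from there, to base_architecture
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
    fconv_iwslt_de_en(args)
```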
@@ -61,6 +61,9 @@ class LSTMModel(FairseqModel):
     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
         """Build a new model instance."""
+        # make sure that all args are properly defaulted (in case there are any new ones)
+        base_architecture(args)
+
         if not hasattr(args, 'encoder_embed_path'):
             args.encoder_embed_path = None
@@ -452,32 +455,23 @@ def base_architecture(args):
 @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
 def lstm_wiseman_iwslt_de_en(args):
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+    args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0)
+    args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
+    args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0)
+    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
     base_architecture(args)
-    args.encoder_embed_dim = 256
-    args.encoder_hidden_size = 256
-    args.encoder_layers = 1
-    args.encoder_bidirectional = False
-    args.encoder_dropout_in = 0
-    args.encoder_dropout_out = 0
-    args.decoder_embed_dim = 256
-    args.decoder_hidden_size = 256
-    args.decoder_layers = 1
-    args.decoder_out_embed_dim = 256
-    args.decoder_attention = '1'
-    args.decoder_dropout_in = 0

 @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
 def lstm_luong_wmt_en_de(args):
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
+    args.encoder_layers = getattr(args, 'encoder_layers', 4)
+    args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000)
+    args.decoder_layers = getattr(args, 'decoder_layers', 4)
+    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000)
+    args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0)
     base_architecture(args)
-    args.encoder_embed_dim = 1000
-    args.encoder_hidden_size = 1000
-    args.encoder_layers = 4
-    args.encoder_dropout_out = 0
-    args.encoder_bidirectional = False
-    args.decoder_embed_dim = 1000
-    args.decoder_hidden_size = 1000
-    args.decoder_layers = 4
-    args.decoder_out_embed_dim = 1000
-    args.decoder_attention = '1'
-    args.decoder_dropout_out = 0
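
Several old assignments (e.g. `encoder_hidden_size`, `decoder_attention`) disappear from the LSTM presets entirely, presumably because `base_architecture` now defaults or derives them; since it runs last, it only fills in whatever is still unset. A small sketch of that layering with simplified stand-in functions (the defaults and the derivation shown are illustrative, not the real LSTM ones):

```python
from argparse import Namespace

def base_architecture(args):              # simplified stand-in
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim)

def lstm_small(args):                     # stand-in for a registered preset
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    base_architecture(args)               # runs last, fills what is still unset

args = Namespace()
lstm_small(args)
assert args.encoder_embed_dim == 256      # preset default applied
assert args.encoder_hidden_size == 256    # derived by the base function
```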
@@ -96,6 +96,7 @@ class TransformerModel(FairseqModel):
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
+
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
@@ -155,6 +156,7 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqIncrementalDecoder):
     """Transformer decoder."""
+
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
@@ -250,6 +252,7 @@ class TransformerEncoderLayer(nn.Module):
     We default to the approach in the paper, but the tensor2tensor approach can
     be enabled by setting `normalize_before=True`.
     """
+
     def __init__(self, args):
         super().__init__()
         self.embed_dim = args.encoder_embed_dim
@@ -292,6 +295,7 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""
+
     def __init__(self, args):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
@@ -399,56 +403,47 @@ def base_architecture(args):
 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
+    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
+    args.encoder_layers = getattr(args, 'encoder_layers', 3)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
+    args.decoder_layers = getattr(args, 'decoder_layers', 3)
     base_architecture(args)
-    args.encoder_embed_dim = 256
-    args.encoder_ffn_embed_dim = 512
-    args.encoder_layers = 3
-    args.encoder_attention_heads = 4
-    args.decoder_embed_dim = 256
-    args.decoder_ffn_embed_dim = 512
-    args.decoder_layers = 3
-    args.decoder_attention_heads = 4

 @register_model_architecture('transformer', 'transformer_wmt_en_de')
 def transformer_wmt_en_de(args):
     base_architecture(args)
-    args.encoder_embed_dim = 512
-    args.encoder_ffn_embed_dim = 2048
-    args.encoder_layers = 6
-    args.encoder_attention_heads = 8
-    args.decoder_embed_dim = 512
-    args.decoder_ffn_embed_dim = 2048
-    args.decoder_layers = 6
-    args.decoder_attention_heads = 8

 # parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
 @register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
 def transformer_vaswani_wmt_en_de_big(args):
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
+    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
+    args.dropout = getattr(args, 'dropout', 0.3)
     base_architecture(args)
-    args.encoder_embed_dim = 1024
-    args.encoder_ffn_embed_dim = 4096
-    args.encoder_layers = 6
-    args.encoder_attention_heads = 16
-    args.decoder_embed_dim = 1024
-    args.decoder_ffn_embed_dim = 4096
-    args.decoder_layers = 6
-    args.decoder_attention_heads = 16
-    args.dropout = 0.3

 @register_model_architecture('transformer', 'transformer_wmt_en_de_big')
 def transformer_wmt_en_de_big(args):
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     transformer_vaswani_wmt_en_de_big(args)
-    args.attention_dropout = 0.1

 # default parameters used in tensor2tensor implementation
 @register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
 def transformer_wmt_en_de_big_t2t(args):
+    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
+    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
     transformer_vaswani_wmt_en_de_big(args)
-    args.encoder_normalize_before = True
-    args.decoder_normalize_before = True
-    args.attention_dropout = 0.1
-    args.relu_dropout = 0.1
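
The big transformer variants now chain their defaults: `transformer_wmt_en_de_big_t2t` sets its own `getattr` defaults, then calls `transformer_vaswani_wmt_en_de_big`, which finishes with `base_architecture`, so each level can only fill values that are still unset. A rough mock of how an explicit override flows through such a chain (simplified stand-ins, not the real functions):

```python
from argparse import Namespace

def base_architecture(args):              # simplified stand-in
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)

def vaswani_big(args):                    # ~ transformer_vaswani_wmt_en_de_big
    args.dropout = getattr(args, 'dropout', 0.3)
    base_architecture(args)

def big_t2t(args):                        # ~ transformer_wmt_en_de_big_t2t
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
    vaswani_big(args)

args = Namespace(attention_dropout=0.2)   # explicit user override
big_t2t(args)
assert args.attention_dropout == 0.2      # override survives the whole chain
assert args.dropout == 0.3                # filled in by the chained defaults
```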