Commit 1235aa08 authored by Myle Ott

Pass args around to cleanup parameter lists

parent 559eca81
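The diff below applies one mechanical refactoring throughout the model: constructors that previously took long keyword lists now accept the parsed command-line args namespace and read the fields they need. A minimal sketch of the pattern follows; EncoderLayerBefore and EncoderLayerAfter are hypothetical names used only for illustration, not classes from this commit.

import argparse

# Before: every hyperparameter is threaded through the signature.
class EncoderLayerBefore:
    def __init__(self, embed_dim, ffn_inner_dim, num_attn_heads,
                 dropout=0.1, attention_dropout=0., relu_dropout=0.,
                 normalize_before=False):
        self.embed_dim = embed_dim
        self.dropout = dropout

# After: the module pulls what it needs from the args namespace.
class EncoderLayerAfter:
    def __init__(self, args):
        self.embed_dim = args.encoder_embed_dim
        self.dropout = args.dropout

args = argparse.Namespace(encoder_embed_dim=512, dropout=0.1)
layer = EncoderLayerAfter(args)  # call sites shrink to a single argument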
@@ -89,44 +89,16 @@ class TransformerModel(FairseqModel):
         encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
         decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
-        encoder = TransformerEncoder(
-            src_dict,
-            encoder_embed_tokens,
-            ffn_inner_dim=args.encoder_ffn_embed_dim,
-            num_layers=args.encoder_layers,
-            num_attn_heads=args.encoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.encoder_learned_pos,
-        )
-        decoder = TransformerDecoder(
-            dst_dict,
-            decoder_embed_tokens,
-            ffn_inner_dim=args.decoder_ffn_embed_dim,
-            num_layers=args.decoder_layers,
-            num_attn_heads=args.decoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.decoder_learned_pos,
-            share_input_output_embed=args.share_decoder_input_output_embed,
-        )
+        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
+        decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)
 
 
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
 
-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048,
-        num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
-        relu_dropout=0., normalize_before=False, learned_pos_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
+        self.dropout = args.dropout
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
@@ -136,17 +108,13 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, self.padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
-            learned=learned_pos_embed,
+            learned=args.encoder_learned_pos,
         )
 
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerEncoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
-                attention_dropout=attention_dropout, relu_dropout=relu_dropout,
-                normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerEncoderLayer(args)
+            for i in range(args.encoder_layers)
         ])
 
         self.reset_parameters()
@@ -186,15 +154,10 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqDecoder):
     """Transformer decoder."""
 
-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048,
-        num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
-        relu_dropout=0., normalize_before=False, learned_pos_embed=False,
-        share_input_output_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
-        self.share_input_output_embed = share_input_output_embed
+        self.dropout = args.dropout
+        self.share_input_output_embed = args.share_decoder_input_output_embed
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -204,20 +167,16 @@ class TransformerDecoder(FairseqDecoder):
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-            learned=learned_pos_embed,
+            learned=args.decoder_learned_pos,
         )
 
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerDecoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
-                attention_dropout=attention_dropout, relu_dropout=relu_dropout,
-                normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerDecoderLayer(args)
+            for i in range(args.decoder_layers)
        ])
 
-        if not share_input_output_embed:
+        if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
 
         self.reset_parameters()
@@ -276,19 +235,19 @@ class TransformerEncoderLayer(nn.Module):
     We default to the approach in the paper, but the tensor2tensor approach can
     be enabled by setting `normalize_before=True`.
     """
 
-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(2)])
+        self.embed_dim = args.encoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.encoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.fc1 = nn.Linear(self.embed_dim, args.encoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
 
     def forward(self, x, encoder_padding_mask):
         residual = x
@@ -318,20 +277,23 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""
 
-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.encoder_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(3)])
+        self.embed_dim = args.decoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.encoder_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.fc1 = nn.Linear(self.embed_dim, args.decoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
 
     def forward(self, x, encoder_out, encoder_padding_mask):
         residual = x
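The TransformerEncoderLayer docstring in the diff above distinguishes the layer-norm placement used in the paper (applied after the residual connection) from the tensor2tensor variant (applied before the sublayer), toggled by normalize_before, which the new code reads from args.encoder_normalize_before. Below is a minimal, self-contained sketch of the two orderings, assuming only PyTorch; apply_sublayer and dim are hypothetical names, not the forward pass from this file.

import torch
import torch.nn as nn
import torch.nn.functional as F

def apply_sublayer(x, sublayer, layer_norm, dropout_p, normalize_before):
    residual = x
    if normalize_before:
        # tensor2tensor-style pre-norm: norm, sublayer, dropout, residual add
        x = layer_norm(x)
        x = F.dropout(sublayer(x), p=dropout_p)
        return residual + x
    else:
        # paper-style post-norm: sublayer, dropout, residual add, norm
        x = F.dropout(sublayer(x), p=dropout_p)
        return layer_norm(residual + x)

dim = 8
x = torch.randn(2, 3, dim)
ffn = nn.Linear(dim, dim)
norm = nn.LayerNorm(dim)
post = apply_sublayer(x, ffn, norm, dropout_p=0.1, normalize_before=False)
pre = apply_sublayer(x, ffn, norm, dropout_p=0.1, normalize_before=True)

The pre-norm ordering keeps an unnormalized identity path through the residual connection, which the tensor2tensor authors report as more robust to train; the post-norm ordering matches the original paper and remains the default here.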