"docs/source/vscode:/vscode.git/clone" did not exist on "1d3de6452fdfc9ddbec9f58a7ad8c6da02408bca"
Commit 1235aa08 authored by Myle Ott

Pass args around to cleanup parameter lists

parent 559eca81
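The refactor is easiest to see in miniature: instead of threading each hyperparameter through as its own keyword argument, the argparse-style `args` namespace is handed to the module, which reads the fields it needs. The sketch below is illustrative only, not fairseq code: the `ToyEncoderLayer` class and the numeric values are invented for the example, and only the field names (`encoder_embed_dim`, `encoder_ffn_embed_dim`, `dropout`) are borrowed from the diff.

from argparse import Namespace

import torch
import torch.nn as nn


class ToyEncoderLayer(nn.Module):
    """Reads its hyperparameters directly from `args` instead of a long parameter list."""

    def __init__(self, args):
        super().__init__()
        self.fc1 = nn.Linear(args.encoder_embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, args.encoder_embed_dim)
        self.dropout = nn.Dropout(args.dropout)

    def forward(self, x):
        # feed-forward sublayer with a residual connection
        return x + self.fc2(self.dropout(torch.relu(self.fc1(x))))


# One object carries the whole configuration; constructors stay short.
args = Namespace(encoder_embed_dim=512, encoder_ffn_embed_dim=2048, dropout=0.1)
layer = ToyEncoderLayer(args)
out = layer(torch.randn(10, 512))

The trade-off is that a module's hyperparameter dependencies become implicit in its body rather than explicit in its signature, so defaults must be established on `args` (for example by argparse) before any module is constructed.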
@@ -89,44 +89,16 @@ class TransformerModel(FairseqModel):
         encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
         decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
-        encoder = TransformerEncoder(
-            src_dict,
-            encoder_embed_tokens,
-            ffn_inner_dim=args.encoder_ffn_embed_dim,
-            num_layers=args.encoder_layers,
-            num_attn_heads=args.encoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.encoder_learned_pos,
-        )
-        decoder = TransformerDecoder(
-            dst_dict,
-            decoder_embed_tokens,
-            ffn_inner_dim=args.decoder_ffn_embed_dim,
-            num_layers=args.decoder_layers,
-            num_attn_heads=args.decoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.decoder_learned_pos,
-            share_input_output_embed=args.share_decoder_input_output_embed,
-        )
+        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
+        decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048,
-        num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
-        relu_dropout=0., normalize_before=False, learned_pos_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
+        self.dropout = args.dropout
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
@@ -136,17 +108,13 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, self.padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
-            learned=learned_pos_embed,
+            learned=args.encoder_learned_pos,
         )
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerEncoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
-                attention_dropout=attention_dropout, relu_dropout=relu_dropout,
-                normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerEncoderLayer(args)
+            for i in range(args.encoder_layers)
         ])
         self.reset_parameters()
@@ -186,15 +154,10 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqDecoder):
     """Transformer decoder."""
-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048,
-        num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
-        relu_dropout=0., normalize_before=False, learned_pos_embed=False,
-        share_input_output_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
-        self.share_input_output_embed = share_input_output_embed
+        self.dropout = args.dropout
+        self.share_input_output_embed = args.share_decoder_input_output_embed
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -204,20 +167,16 @@ class TransformerDecoder(FairseqDecoder):
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-            learned=learned_pos_embed,
+            learned=args.decoder_learned_pos,
         )
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerDecoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
-                attention_dropout=attention_dropout, relu_dropout=relu_dropout,
-                normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerDecoderLayer(args)
+            for i in range(args.decoder_layers)
         ])
-        if not share_input_output_embed:
+        if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
         self.reset_parameters()
@@ -276,19 +235,19 @@ class TransformerEncoderLayer(nn.Module):
     We default to the approach in the paper, but the tensor2tensor approach can
     be enabled by setting `normalize_before=True`.
     """
-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(2)])
+        self.embed_dim = args.encoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.encoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.fc1 = nn.Linear(self.embed_dim, args.encoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
     def forward(self, x, encoder_padding_mask):
         residual = x
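The TransformerEncoderLayer docstring above contrasts the layer-norm placement from the original paper with the tensor2tensor variant selected by `normalize_before`. As a rough sketch (not the fairseq implementation; `sublayer` and the toy dimensions below are stand-ins for the self-attention or feed-forward blocks), the two orderings differ only in whether normalization happens before the sublayer or after the residual sum:

import torch
import torch.nn as nn


def post_norm(x, sublayer, layer_norm):
    # Paper ordering (normalize_before=False): apply the sublayer, add the
    # residual, then normalize.
    return layer_norm(x + sublayer(x))


def pre_norm(x, sublayer, layer_norm):
    # tensor2tensor ordering (normalize_before=True): normalize first, apply
    # the sublayer, then add the residual.
    return x + sublayer(layer_norm(x))


dim = 8
ln = nn.LayerNorm(dim)
ff = nn.Linear(dim, dim)  # stand-in sublayer
x = torch.randn(2, dim)
assert post_norm(x, ff, ln).shape == pre_norm(x, ff, ln).shape == x.shape

The `layer_norms` ModuleList built in the diff (two norms for the encoder layer, three for the decoder layer) supplies one LayerNorm per sublayer, and `normalize_before` selects which of the two orderings the forward pass uses.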
@@ -318,20 +277,23 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""
-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.encoder_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(3)])
+        self.embed_dim = args.decoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.encoder_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.fc1 = nn.Linear(self.embed_dim, args.decoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
     def forward(self, x, encoder_out, encoder_padding_mask):
         residual = x