OpenDAS / Fairseq · Commits

Commit 1235aa08
authored Mar 13, 2018 by Myle Ott

Pass args around to cleanup parameter lists

parent 559eca81

Showing 1 changed file with 42 additions and 80 deletions

fairseq/models/transformer.py  (+42, -80)
@@ -89,44 +89,16 @@ class TransformerModel(FairseqModel):
         encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
         decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
-        encoder = TransformerEncoder(
-            src_dict,
-            encoder_embed_tokens,
-            ffn_inner_dim=args.encoder_ffn_embed_dim,
-            num_layers=args.encoder_layers,
-            num_attn_heads=args.encoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.encoder_learned_pos,
-        )
-        decoder = TransformerDecoder(
-            dst_dict,
-            decoder_embed_tokens,
-            ffn_inner_dim=args.decoder_ffn_embed_dim,
-            num_layers=args.decoder_layers,
-            num_attn_heads=args.decoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.decoder_learned_pos,
-            share_input_output_embed=args.share_decoder_input_output_embed,
-        )
+        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
+        decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)


 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""

-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048, num_layers=6,
-        num_attn_heads=8, dropout=0.1, attention_dropout=0., relu_dropout=0.,
-        normalize_before=False, learned_pos_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
+        self.dropout = args.dropout
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
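With this change, the model-level constructors no longer thread a dozen keyword arguments through; everything is read from a single `args` namespace. As a minimal sketch, the namespace below lists exactly the attributes the new code in this diff reads; the numeric values are illustrative placeholders and are not taken from this commit:

from argparse import Namespace

# Attribute names are the ones referenced by the new TransformerEncoder /
# TransformerDecoder / *Layer code in this diff; values are placeholders.
args = Namespace(
    dropout=0.1,
    attention_dropout=0.0,
    relu_dropout=0.0,
    encoder_embed_dim=512,
    encoder_ffn_embed_dim=2048,
    encoder_layers=6,
    encoder_attention_heads=8,
    encoder_normalize_before=False,
    encoder_learned_pos=False,
    decoder_embed_dim=512,
    decoder_ffn_embed_dim=2048,
    decoder_layers=6,
    decoder_attention_heads=8,
    decoder_learned_pos=False,
    share_decoder_input_output_embed=False,
)

# Per the diff, construction then collapses to:
#   encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
#   decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
# (src_dict, dst_dict and the embedding tables come from the surrounding
# model-construction code, which is unchanged and not shown in this diff.)
print(args.encoder_layers)  # 6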
@@ -136,17 +108,13 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, self.padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
-            learned=learned_pos_embed,
+            learned=args.encoder_learned_pos,
         )

         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerEncoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads,
-                dropout=dropout, attention_dropout=attention_dropout,
-                relu_dropout=relu_dropout, normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerEncoderLayer(args)
+            for i in range(args.encoder_layers)
         ])

         self.reset_parameters()
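The encoder hunk shows the same idea one level down: each `TransformerEncoderLayer(args)` now pulls its own sizes out of the namespace, and `args.encoder_layers` controls the depth of the stack. A self-contained sketch of that construction idiom with a stand-in layer class (the class and sizes here are illustrative, not fairseq's):

import torch.nn as nn
from argparse import Namespace


class DummyEncoderLayer(nn.Module):
    # Stand-in for TransformerEncoderLayer: reads its sizes from `args`.
    def __init__(self, args):
        super().__init__()
        self.fc1 = nn.Linear(args.encoder_embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, args.encoder_embed_dim)

    def forward(self, x):
        return self.fc2(self.fc1(x).relu())


args = Namespace(encoder_embed_dim=512, encoder_ffn_embed_dim=2048, encoder_layers=6)

# Same idiom as the diff: one args object drives both the per-layer
# hyperparameters and the number of layers.
layers = nn.ModuleList([DummyEncoderLayer(args) for i in range(args.encoder_layers)])
print(len(layers))  # 6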
@@ -186,15 +154,10 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqDecoder):
     """Transformer decoder."""

-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048, num_layers=6,
-        num_attn_heads=8, dropout=0.1, attention_dropout=0., relu_dropout=0.,
-        normalize_before=False, learned_pos_embed=False, share_input_output_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
-        self.share_input_output_embed = share_input_output_embed
+        self.dropout = args.dropout
+        self.share_input_output_embed = args.share_decoder_input_output_embed
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -204,20 +167,16 @@ class TransformerDecoder(FairseqDecoder):
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-            learned=learned_pos_embed,
+            learned=args.decoder_learned_pos,
         )

         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerDecoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads,
-                dropout=dropout, attention_dropout=attention_dropout,
-                relu_dropout=relu_dropout, normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerDecoderLayer(args)
+            for i in range(args.decoder_layers)
         ])

-        if not share_input_output_embed:
+        if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))

         self.reset_parameters()
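The one conditional in the decoder constructor is the output projection: when `share_decoder_input_output_embed` is off, a separate `embed_out` matrix of shape `(len(dictionary), embed_dim)` is allocated. A hedged sketch of the two cases; the shared branch reuses the token-embedding weight, which is the usual weight-tying arrangement, while the actual projection code lives in the unchanged part of the file:

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embed_dim = 1000, 512         # illustrative sizes
embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

share_input_output_embed = False          # i.e. args.share_decoder_input_output_embed

if share_input_output_embed:
    # Weight tying: score output tokens against the input embedding matrix.
    out_weight = embed_tokens.weight
else:
    # Separate projection, matching the diff:
    #   self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
    out_weight = nn.Parameter(torch.Tensor(vocab_size, embed_dim))
    nn.init.normal_(out_weight, mean=0, std=embed_dim ** -0.5)

hidden = torch.zeros(2, 7, embed_dim)     # (batch, time, embed_dim)
logits = F.linear(hidden, out_weight)
print(logits.shape)                       # torch.Size([2, 7, 1000])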
@@ -276,19 +235,19 @@ class TransformerEncoderLayer(nn.Module):
     We default to the approach in the paper, but the tensor2tensor approach can
     be enabled by setting `normalize_before=True`.
     """

-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(2)])
+        self.embed_dim = args.encoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.encoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.fc1 = nn.Linear(self.embed_dim, args.encoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])

     def forward(self, x, encoder_padding_mask):
         residual = x
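For context on the `normalize_before` flag that now arrives as `args.encoder_normalize_before`: the docstring above distinguishes the paper's post-norm ordering (residual add, then LayerNorm) from the tensor2tensor pre-norm ordering (LayerNorm before the sublayer). A generic sketch of the two orderings around a single sublayer; this mirrors the idea only, not the file's exact helper functions:

import torch
import torch.nn as nn

embed_dim = 512
layer_norm = nn.LayerNorm(embed_dim)
sublayer = nn.Linear(embed_dim, embed_dim)   # stand-in for self-attention / FFN
normalize_before = True                      # i.e. args.encoder_normalize_before

x = torch.randn(10, 2, embed_dim)            # (time, batch, embed_dim), illustrative
residual = x
if normalize_before:
    # tensor2tensor style: normalize the input of the sublayer
    x = residual + sublayer(layer_norm(x))
else:
    # paper style: normalize after the residual connection
    x = layer_norm(residual + sublayer(x))
print(x.shape)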
@@ -318,20 +277,23 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""

-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.encoder_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(3)])
+        self.embed_dim = args.decoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.encoder_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.fc1 = nn.Linear(self.embed_dim, args.decoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])

     def forward(self, x, encoder_out, encoder_padding_mask):
         residual = x