OpenDAS / Fairseq · Commits

Commit 1d38624f
authored Aug 04, 2018 by alexeib, committed by Myle Ott on Sep 03, 2018

parameters to separate input/inner/out dims

parent e4f51e18

Showing 1 changed file with 40 additions and 11 deletions:

fairseq/models/transformer.py  (+40, -11)
@@ -145,6 +145,10 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='dropout probability after ReLU in FFN')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
+        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
+                            help='decoder output dimension')
+        parser.add_argument('--decoder-input-dim', type=int, metavar='N',
+                            help='decoder input dimension')
         parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension for FFN')
         parser.add_argument('--decoder-layers', type=int, metavar='N',
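The hunk above introduces --decoder-input-dim and --decoder-output-dim alongside the existing --decoder-embed-dim, so the token-embedding width, the inner model width, and the pre-softmax width can be set independently. A standalone sketch of how such flags parse, using a throwaway argparse parser and made-up defaults rather than fairseq's real entry point:

# Illustrative only: a standalone parser mirroring the options added above.
# The defaults (512/128) are hypothetical, not values from the commit.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--decoder-embed-dim', type=int, metavar='N', default=512,
                    help='decoder embedding dimension (inner model width)')
parser.add_argument('--decoder-input-dim', type=int, metavar='N', default=128,
                    help='decoder input dimension (token embedding width)')
parser.add_argument('--decoder-output-dim', type=int, metavar='N', default=128,
                    help='decoder output dimension (pre-softmax width)')

args = parser.parse_args(['--decoder-embed-dim', '1024'])
print(args.decoder_embed_dim, args.decoder_input_dim, args.decoder_output_dim)
# -> 1024 128 128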
@@ -191,9 +195,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
                 args.char_embedder_highway_layers,
             )
         else:
-            embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
-
-        print(args)
+            embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
         decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
         return TransformerLanguageModel(decoder)
@@ -291,12 +293,19 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
-        embed_dim = embed_tokens.embedding_dim
+
+        input_embed_dim = embed_tokens.embedding_dim
+        embed_dim = args.decoder_embed_dim
+        output_embed_dim = args.decoder_output_dim
+
         padding_idx = embed_tokens.padding_idx
         self.max_target_positions = args.max_target_positions

         self.embed_tokens = embed_tokens
-        self.embed_scale = math.sqrt(embed_dim)
+        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
+
+        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False, uniform=False) if embed_dim != input_embed_dim else None
+
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
             left_pad=left_pad,
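The constructor hunk above distinguishes input_embed_dim (the width of embed_tokens), embed_dim (the inner decoder width from --decoder-embed-dim), and output_embed_dim, and creates project_in_dim only when the input and inner widths differ. A minimal, self-contained sketch of that in-projection pattern, with illustrative sizes and a plain nn.Linear standing in for the file's Linear helper:

# Assumed/simplified sketch, not fairseq's TransformerDecoder: embed tokens at a
# small width, then project up to the model's inner width only when the two differ.
import math
import torch
import torch.nn as nn

vocab_size, input_embed_dim, embed_dim = 1000, 128, 512

embed_tokens = nn.Embedding(vocab_size, input_embed_dim, padding_idx=0)
embed_scale = math.sqrt(embed_dim)
project_in_dim = (nn.Linear(input_embed_dim, embed_dim, bias=False)
                  if embed_dim != input_embed_dim else None)

tokens = torch.randint(1, vocab_size, (2, 7))   # batch x time
x = embed_scale * embed_tokens(tokens)          # -> 2 x 7 x 128
if project_in_dim is not None:
    x = project_in_dim(x)                       # -> 2 x 7 x 512
print(x.shape)  # torch.Size([2, 7, 512])

Scaling by sqrt(embed_dim) rather than sqrt(input_embed_dim) mirrors the line above, which the commit itself flags with a "# todo: try with input_embed_dim" comment.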
@@ -311,15 +320,18 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.adaptive_softmax = None

+        self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False, uniform=False) if embed_dim != output_embed_dim else None
+
         if args.adaptive_softmax_cutoff is not None:
             self.adaptive_softmax = AdaptiveSoftmax(
-                len(dictionary), args.decoder_embed_dim,
+                len(dictionary), output_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                 dropout=args.adaptive_softmax_dropout,
             )
         elif not self.share_input_output_embed:
-            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
+            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
+            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
         self.register_buffer('version', torch.Tensor([2]))
         self.normalize = args.decoder_normalize_before
         if self.normalize:
@@ -339,6 +351,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(prev_output_tokens)

+        if self.project_in_dim is not None:
+            x = self.project_in_dim(x)
+
         if positions is not None:
             x += positions
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -362,6 +378,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

+        if self.project_out_dim is not None:
+            x = self.project_out_dim(x)
+
         if self.adaptive_softmax is None:
             # project back to size of vocabulary
             if self.share_input_output_embed:
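Symmetrically, the forward pass now maps decoder states from embed_dim back down to output_embed_dim before the vocabulary projection, and the constructor hunk above sizes the untied embed_out matrix (and the adaptive softmax input) at output_embed_dim. A simplified sketch of that output end, again with illustrative sizes rather than the commit's code:

# Assumed/simplified sketch of the output end after this change: project decoder
# states from embed_dim down to output_embed_dim, then compute vocabulary logits
# against an output embedding sized at output_embed_dim.
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, output_embed_dim, vocab_size = 512, 128, 1000

project_out_dim = (nn.Linear(embed_dim, output_embed_dim, bias=False)
                   if embed_dim != output_embed_dim else None)
embed_out = nn.Parameter(torch.Tensor(vocab_size, output_embed_dim))
nn.init.normal_(embed_out, mean=0, std=output_embed_dim ** -0.5)

x = torch.randn(2, 7, embed_dim)          # batch x time x embed_dim
if project_out_dim is not None:
    x = project_out_dim(x)                # -> 2 x 7 x output_embed_dim
logits = F.linear(x, embed_out)           # -> 2 x 7 x vocab_size
print(logits.shape)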
@@ -555,10 +574,14 @@ def LayerNorm(embedding_dim):
     return m


-def Linear(in_features, out_features, bias=True):
+def Linear(in_features, out_features, bias=True, uniform=True):
     m = nn.Linear(in_features, out_features, bias)
-    nn.init.xavier_uniform_(m.weight)
-    nn.init.constant_(m.bias, 0.)
+    if uniform:
+        nn.init.xavier_uniform_(m.weight)
+    else:
+        nn.init.xavier_normal_(m.weight)
+    if bias:
+        nn.init.constant_(m.bias, 0.)
     return m
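The Linear helper gains a uniform flag so callers can choose Xavier-uniform (the previous default) or Xavier-normal initialization, and the bias is now only initialized when one exists, which is what lets the projection layers above pass bias=False, uniform=False. A small usage sketch, with the helper repeated from the hunk above so the snippet stands alone:

import torch.nn as nn

def Linear(in_features, out_features, bias=True, uniform=True):
    # Same as the updated helper in the hunk above.
    m = nn.Linear(in_features, out_features, bias)
    if uniform:
        nn.init.xavier_uniform_(m.weight)
    else:
        nn.init.xavier_normal_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m

proj = Linear(128, 512, bias=False, uniform=False)  # Xavier-normal weights, no bias to initialize
ffn = Linear(512, 2048)                             # previous behaviour: Xavier-uniform weights, zero bias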
@@ -584,6 +607,9 @@ def base_lm_architecture(args):
     args.character_embeddings = getattr(args, 'character_embeddings', False)

+    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
+    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
+
     # The model training is not stable without this
     args.decoder_normalize_before = True
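The getattr calls backfill the two new attributes from decoder_embed_dim, so configurations and checkpoints that predate the new flags keep the old single-width behaviour. A tiny illustration of the pattern on a hypothetical Namespace, not fairseq's args:

from argparse import Namespace

args = Namespace(decoder_embed_dim=512)  # old-style config, new flags unset
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
print(args.decoder_input_dim, args.decoder_output_dim)  # 512 512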
@@ -635,6 +661,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
+
+    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
+    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)


 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):