Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
9f3ccaa6
"vscode:/vscode.git/clone" did not exist on "e602ab1b56889c8f999f07aeddb55d641fba1014"
Commit
9f3ccaa6
authored
Dec 04, 2017
by
Myle Ott
Browse files
Fix weight norm dimension in decoder (fixes #73)
parent
99493a85
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
1 deletion
+25
-1
fairseq/models/fairseq_decoder.py
fairseq/models/fairseq_decoder.py
+3
-0
fairseq/models/fairseq_encoder.py
fairseq/models/fairseq_encoder.py
+3
-0
fairseq/models/fairseq_model.py
fairseq/models/fairseq_model.py
+5
-0
fairseq/models/fconv.py
fairseq/models/fconv.py
+12
-1
fairseq/utils.py
fairseq/utils.py
+2
-0
No files found.
fairseq/models/fairseq_decoder.py
View file @
9f3ccaa6
...
@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
...
@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
def
max_positions
(
self
):
def
max_positions
(
self
):
"""Maximum input length supported by the decoder."""
"""Maximum input length supported by the decoder."""
raise
NotImplementedError
raise
NotImplementedError
def upgrade_state_dict(self, state_dict):
    """Hook for upgrading a state dict before it is loaded.

    The base decoder performs no conversion: ``state_dict`` is returned
    unchanged (the very same object).

    Args:
        state_dict: the state dict to upgrade.

    Returns:
        The ``state_dict`` that was passed in.
    """
    return state_dict
fairseq/models/fairseq_encoder.py
View file @
9f3ccaa6
...
@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
...
@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
def
max_positions
(
self
):
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
"""Maximum input length supported by the encoder."""
raise
NotImplementedError
raise
NotImplementedError
def upgrade_state_dict(self, state_dict):
    """Hook for upgrading a state dict before it is loaded.

    The base encoder performs no conversion: ``state_dict`` is returned
    unchanged (the very same object).

    Args:
        state_dict: the state dict to upgrade.

    Returns:
        The ``state_dict`` that was passed in.
    """
    return state_dict
fairseq/models/fairseq_model.py
View file @
9f3ccaa6
...
@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
...
@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder."""
"""Maximum output length supported by the decoder."""
return
self
.
decoder
.
max_positions
()
return
self
.
decoder
.
max_positions
()
def upgrade_state_dict(self, state_dict):
    """Upgrade a state dict via the encoder and then the decoder.

    ``state_dict`` is passed to ``self.encoder.upgrade_state_dict`` first;
    whatever that returns is then passed to
    ``self.decoder.upgrade_state_dict``.

    Args:
        state_dict: the state dict to upgrade.

    Returns:
        The state dict returned by the decoder's upgrade step.
    """
    # Order matters only if a component rewrites keys the other one reads;
    # the encoder-then-decoder order of the original is preserved.
    state_dict = self.encoder.upgrade_state_dict(state_dict)
    state_dict = self.decoder.upgrade_state_dict(state_dict)
    return state_dict
def
make_generation_fast_
(
self
,
**
kwargs
):
def
make_generation_fast_
(
self
,
**
kwargs
):
"""Optimize model for faster generation."""
"""Optimize model for faster generation."""
if
self
.
_is_generation_fast
:
if
self
.
_is_generation_fast
:
...
...
fairseq/models/fconv.py
View file @
9f3ccaa6
...
@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
...
@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
max_positions
=
1024
,
convolutions
=
((
512
,
3
),)
*
20
,
max_positions
=
1024
,
convolutions
=
((
512
,
3
),)
*
20
,
attention
=
True
,
dropout
=
0.1
):
attention
=
True
,
dropout
=
0.1
):
super
().
__init__
()
super
().
__init__
()
self
.
register_buffer
(
'version'
,
torch
.
Tensor
([
2
]))
self
.
dictionary
=
dictionary
self
.
dictionary
=
dictionary
self
.
dropout
=
dropout
self
.
dropout
=
dropout
...
@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
...
@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
"""Maximum output length supported by the decoder."""
"""Maximum output length supported by the decoder."""
return
self
.
embed_positions
.
num_embeddings
-
self
.
dictionary
.
pad
()
-
1
return
self
.
embed_positions
.
num_embeddings
-
self
.
dictionary
.
pad
()
-
1
def upgrade_state_dict(self, state_dict):
    """Adapt this decoder so that a pre-version-2 state dict can be loaded.

    The version is read from ``state_dict['decoder.version']``; a state
    dict without that key is treated as version 1. For versions below 2,
    every module in ``self.convolutions`` has its weight normalization
    removed and re-applied with ``dim=0``, and the entry
    ``'decoder.version'`` is written into ``state_dict`` with value 1.
    State dicts at version 2 or above are returned untouched.

    Note that the *model* is what gets reconfigured here; the only change
    made to ``state_dict`` itself is the ``'decoder.version'`` entry.

    Args:
        state_dict: the full model state dict (keys prefixed with
            ``decoder.`` for this module).

    Returns:
        The same ``state_dict`` object, possibly with
        ``'decoder.version'`` set.
    """
    if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
        # old models use incorrect weight norm dimension
        for i, conv in enumerate(self.convolutions):
            # reconfigure weight norm
            nn.utils.remove_weight_norm(conv)
            self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
        # Record the (old) version explicitly so the key is present in the
        # state dict -- presumably to match the 'version' buffer this
        # decoder registers; confirm against the loading code.
        state_dict['decoder.version'] = torch.Tensor([1])
    return state_dict
def
_split_encoder_out
(
self
,
encoder_out
):
def
_split_encoder_out
(
self
,
encoder_out
):
"""Split and transpose encoder outputs.
"""Split and transpose encoder outputs.
...
@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
...
@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
std
=
math
.
sqrt
((
4
*
(
1.0
-
dropout
))
/
(
m
.
kernel_size
[
0
]
*
in_channels
))
std
=
math
.
sqrt
((
4
*
(
1.0
-
dropout
))
/
(
m
.
kernel_size
[
0
]
*
in_channels
))
m
.
weight
.
data
.
normal_
(
mean
=
0
,
std
=
std
)
m
.
weight
.
data
.
normal_
(
mean
=
0
,
std
=
std
)
m
.
bias
.
data
.
zero_
()
m
.
bias
.
data
.
zero_
()
return
nn
.
utils
.
weight_norm
(
m
)
return
nn
.
utils
.
weight_norm
(
m
,
dim
=
2
)
def
ConvTBC
(
in_channels
,
out_channels
,
kernel_size
,
dropout
=
0
,
**
kwargs
):
def
ConvTBC
(
in_channels
,
out_channels
,
kernel_size
,
dropout
=
0
,
**
kwargs
):
...
...
fairseq/utils.py
View file @
9f3ccaa6
...
@@ -94,6 +94,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
...
@@ -94,6 +94,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
map_location
=
lambda
s
,
l
:
default_restore_location
(
s
,
'cuda:{}'
.
format
(
cuda_device
))
map_location
=
lambda
s
,
l
:
default_restore_location
(
s
,
'cuda:{}'
.
format
(
cuda_device
))
)
)
state
=
_upgrade_state_dict
(
state
)
state
=
_upgrade_state_dict
(
state
)
state
[
'model'
]
=
model
.
upgrade_state_dict
(
state
[
'model'
])
# load model parameters
# load model parameters
try
:
try
:
...
@@ -168,6 +169,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
...
@@ -168,6 +169,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
ensemble
=
[]
ensemble
=
[]
for
state
in
states
:
for
state
in
states
:
model
=
build_model
(
args
,
src_dict
,
dst_dict
)
model
=
build_model
(
args
,
src_dict
,
dst_dict
)
state
[
'model'
]
=
model
.
upgrade_state_dict
(
state
[
'model'
])
model
.
load_state_dict
(
state
[
'model'
])
model
.
load_state_dict
(
state
[
'model'
])
ensemble
.
append
(
model
)
ensemble
.
append
(
model
)
return
ensemble
,
args
return
ensemble
,
args
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment