OpenDAS / Fairseq

Commit e21901e8, authored Oct 31, 2017 by Myle Ott
parent 8f9dd964

Fix interactive.py
Showing 3 changed files with 17 additions and 12 deletions (+17 -12)
fairseq/utils.py    +12  -3
generate.py          +1  -1
interactive.py       +4  -8
fairseq/utils.py

@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location
-from fairseq import criterions, models, tokenizer
+from fairseq import criterions, data, models, tokenizer


 def parse_args_and_arch(parser):
@@ -117,7 +117,12 @@ def _upgrade_state_dict(state):
     return state


-def load_ensemble_for_inference(filenames, src_dict, dst_dict):
+def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
+    """Load an ensemble of models for inference.
+
+    The source and target dictionaries can be given explicitly, or loaded from
+    the `data_dir` directory.
+    """
     # load model architectures and weights
     states = []
     for filename in filenames:
@@ -129,13 +134,17 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
     args = states[0]['args']
     args = _upgrade_args(args)

+    if src_dict is None or dst_dict is None:
+        assert data_dir is not None
+        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
+
     # build ensemble
     ensemble = []
     for state in states:
         model = build_model(args, src_dict, dst_dict)
         model.load_state_dict(state['model'])
         ensemble.append(model)
-    return ensemble
+    return ensemble, args


 def _upgrade_args(args):
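With this change, load_ensemble_for_inference accepts the dictionaries either explicitly or via a data directory, and it now returns the training args stored in the first checkpoint alongside the ensemble. A minimal usage sketch of the two calling conventions, assuming hypothetical checkpoint and data-bin paths:

from fairseq import data, utils

# Hypothetical paths, for illustration only.
checkpoints = ['checkpoints/model1.pt', 'checkpoints/model2.pt']
data_dir = 'data-bin/wmt14.en-fr'

# Option 1: pass the dictionaries explicitly (as generate.py does, using the
# dictionaries from its loaded dataset) and discard the returned args.
src_dict, dst_dict = data.load_dictionaries(data_dir, 'en', 'fr')
models, _ = utils.load_ensemble_for_inference(checkpoints, src_dict, dst_dict)

# Option 2: give only data_dir; the function loads the dictionaries itself
# using the source/target languages recorded in the checkpoint args
# (this is what interactive.py does after this commit).
models, model_args = utils.load_ensemble_for_inference(checkpoints, data_dir=data_dir)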
generate.py

@@ -41,7 +41,7 @@ def main():
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
interactive.py

@@ -26,17 +26,13 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu

-    # Load dictionaries
-    if args.source_lang is None or args.target_lang is None:
-        args.source_lang, args.target_lang, _ = data.infer_language_pair(args.data, ['test'])
-    src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
-
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
+    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
+    src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

-    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
-    print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
+    print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
+    print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))

     # Optimize ensemble for generation
     for model in models:
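The net effect in interactive.py is that the dictionaries and the reported language pair now come from the checkpoint and the preprocessed data directory, rather than being loaded and inferred separately in the script. A condensed sketch of the new flow, with placeholder paths standing in for the values parsed into args.path and args.data:

from fairseq import utils

# Placeholders for values that interactive.py reads from argparse.
checkpoint_paths = ['checkpoints/checkpoint_best.pt']  # args.path
data_dir = 'data-bin/wmt14.en-fr'                      # args.data

# Dictionaries are loaded inside load_ensemble_for_inference from data_dir,
# and the training args come back from the checkpoint.
models, model_args = utils.load_ensemble_for_inference(checkpoint_paths, data_dir=data_dir)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

# The language pair printed here is the one stored in the checkpoint's
# training args, so the script no longer calls data.infer_language_pair.
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))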