Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
e4f51e18
Commit
e4f51e18
authored
Aug 03, 2018
by
alexeib
Committed by
Myle Ott
Sep 03, 2018
Browse files
load args from model for eval_lm
parent
45082e48
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
9 deletions
+16
-9
eval_lm.py
eval_lm.py
+14
-9
fairseq/models/transformer.py
fairseq/models/transformer.py
+2
-0
No files found.
eval_lm.py
View file @
e4f51e18
...
...
@@ -14,23 +14,28 @@ from fairseq.meters import StopwatchMeter, TimeMeter
from
fairseq.sequence_scorer
import
SequenceScorer
def
main
(
args
):
assert
args
.
path
is
not
None
,
'--path required for evaluation!'
def
main
(
parsed_
args
):
assert
parsed_
args
.
path
is
not
None
,
'--path required for evaluation!'
args
.
tokens_per_sample
=
getattr
(
args
,
'tokens_per_sample'
,
1024
)
print
(
parsed_args
)
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
parsed_args
.
cpu
task
=
tasks
.
setup_task
(
parsed_args
)
# Load ensemble
print
(
'| loading model(s) from {}'
.
format
(
parsed_args
.
path
))
models
,
args
=
utils
.
load_ensemble_for_inference
(
parsed_args
.
path
.
split
(
':'
),
task
)
args
.
__dict__
.
update
(
parsed_args
.
__dict__
)
print
(
args
)
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
args
.
cpu
task
.
args
=
args
# Load dataset splits
task
=
tasks
.
setup_task
(
args
)
task
.
load_dataset
(
args
.
gen_subset
)
print
(
'| {} {} {} examples'
.
format
(
args
.
data
,
args
.
gen_subset
,
len
(
task
.
dataset
(
args
.
gen_subset
))))
# Load ensemble
print
(
'| loading model(s) from {}'
.
format
(
args
.
path
))
models
,
_
=
utils
.
load_ensemble_for_inference
(
args
.
path
.
split
(
':'
),
task
)
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for
model
in
models
:
model
.
make_generation_fast_
()
...
...
fairseq/models/transformer.py
View file @
e4f51e18
...
...
@@ -193,6 +193,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
else
:
embed_tokens
=
Embedding
(
len
(
task
.
dictionary
),
args
.
decoder_embed_dim
,
task
.
dictionary
.
pad
())
print
(
args
)
decoder
=
TransformerDecoder
(
args
,
task
.
dictionary
,
embed_tokens
,
no_encoder_attn
=
True
)
return
TransformerLanguageModel
(
decoder
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment