Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
e4f51e18
"vscode:/vscode.git/clone" did not exist on "c944f0651f679728d4ec7b6488120ac49c2f1315"
Commit
e4f51e18
authored
Aug 03, 2018
by
alexeib
Committed by
Myle Ott
Sep 03, 2018
Browse files
load args from model for eval_lm
parent
45082e48
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
9 deletions
+16
-9
eval_lm.py
eval_lm.py
+14
-9
fairseq/models/transformer.py
fairseq/models/transformer.py
+2
-0
No files found.
eval_lm.py
View file @
e4f51e18
...
...
@@ -14,23 +14,28 @@ from fairseq.meters import StopwatchMeter, TimeMeter
from
fairseq.sequence_scorer
import
SequenceScorer
def
main
(
args
):
assert
args
.
path
is
not
None
,
'--path required for evaluation!'
def
main
(
parsed_
args
):
assert
parsed_
args
.
path
is
not
None
,
'--path required for evaluation!'
args
.
tokens_per_sample
=
getattr
(
args
,
'tokens_per_sample'
,
1024
)
print
(
parsed_args
)
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
parsed_args
.
cpu
task
=
tasks
.
setup_task
(
parsed_args
)
# Load ensemble
print
(
'| loading model(s) from {}'
.
format
(
parsed_args
.
path
))
models
,
args
=
utils
.
load_ensemble_for_inference
(
parsed_args
.
path
.
split
(
':'
),
task
)
args
.
__dict__
.
update
(
parsed_args
.
__dict__
)
print
(
args
)
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
args
.
cpu
task
.
args
=
args
# Load dataset splits
task
=
tasks
.
setup_task
(
args
)
task
.
load_dataset
(
args
.
gen_subset
)
print
(
'| {} {} {} examples'
.
format
(
args
.
data
,
args
.
gen_subset
,
len
(
task
.
dataset
(
args
.
gen_subset
))))
# Load ensemble
print
(
'| loading model(s) from {}'
.
format
(
args
.
path
))
models
,
_
=
utils
.
load_ensemble_for_inference
(
args
.
path
.
split
(
':'
),
task
)
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for
model
in
models
:
model
.
make_generation_fast_
()
...
...
fairseq/models/transformer.py
View file @
e4f51e18
...
...
@@ -193,6 +193,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
else
:
embed_tokens
=
Embedding
(
len
(
task
.
dictionary
),
args
.
decoder_embed_dim
,
task
.
dictionary
.
pad
())
print
(
args
)
decoder
=
TransformerDecoder
(
args
,
task
.
dictionary
,
embed_tokens
,
no_encoder_attn
=
True
)
return
TransformerLanguageModel
(
decoder
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment