Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
8f058ea0
Commit
8f058ea0
authored
Oct 11, 2017
by
Sergey Edunov
Browse files
Don't generate during training, add --quiet to generate.py
parent
a8260d52
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
13 additions
and
33 deletions
+13
-33
fairseq/data.py
fairseq/data.py
+5
-5
fairseq/options.py
fairseq/options.py
+2
-0
fairseq/utils.py
fairseq/utils.py
+2
-2
generate.py
generate.py
+3
-1
train.py
train.py
+1
-25
No files found.
fairseq/data.py
View file @
8f058ea0
...
...
@@ -17,7 +17,7 @@ from fairseq.dictionary import Dictionary
from
fairseq.indexed_dataset
import
IndexedDataset
,
IndexedInMemoryDataset
def
load_with_check
(
path
,
src
=
None
,
dst
=
None
):
def
load_with_check
(
path
,
load_splits
,
src
=
None
,
dst
=
None
):
"""Loads the train, valid, and test sets from the specified folder
and check that training files exist."""
...
...
@@ -43,12 +43,12 @@ def load_with_check(path, src=None, dst=None):
else
:
raise
ValueError
(
'training file not found for {}-{}'
.
format
(
src
,
dst
))
dataset
=
load
(
path
,
src
,
dst
)
dataset
=
load
(
path
,
load_splits
,
src
,
dst
)
return
dataset
def
load
(
path
,
src
,
dst
):
"""Loads
the train, valid, and test sets from the specified folder
."""
def
load
(
path
,
load_splits
,
src
,
dst
):
"""Loads
specified data splits (e.g. test, train or valid) from the path
."""
langcode
=
'{}-{}'
.
format
(
src
,
dst
)
...
...
@@ -59,7 +59,7 @@ def load(path, src, dst):
dst_dict
=
Dictionary
.
load
(
fmt_path
(
'dict.{}.txt'
,
dst
))
dataset
=
LanguageDatasets
(
src
,
dst
,
src_dict
,
dst_dict
)
for
split
in
[
'train'
,
'valid'
,
'test'
]
:
for
split
in
load_splits
:
for
k
in
itertools
.
count
():
prefix
=
"{}{}"
.
format
(
split
,
k
if
k
>
0
else
''
)
src_path
=
fmt_path
(
'{}.{}.{}'
,
prefix
,
langcode
,
src
)
...
...
fairseq/options.py
View file @
8f058ea0
...
...
@@ -104,6 +104,8 @@ def add_generation_args(parser):
help
=
'length penalty: <1.0 favors shorter, >1.0 favors longer sentences'
)
group
.
add_argument
(
'--unk-replace-dict'
,
default
=
''
,
type
=
str
,
help
=
'performs unk word replacement'
)
group
.
add_argument
(
'--quiet'
,
action
=
'store_true'
,
help
=
'Only print final scores'
)
return
group
...
...
fairseq/utils.py
View file @
8f058ea0
...
...
@@ -94,7 +94,7 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
return
epoch
,
batch_offset
def
load_ensemble_for_inference
(
filenames
,
data_path
):
def
load_ensemble_for_inference
(
filenames
,
data_path
,
split
):
# load model architectures and weights
states
=
[]
for
filename
in
filenames
:
...
...
@@ -106,7 +106,7 @@ def load_ensemble_for_inference(filenames, data_path):
# load dataset
args
=
states
[
0
][
'args'
]
dataset
=
data
.
load
(
data_path
,
args
.
source_lang
,
args
.
target_lang
)
dataset
=
data
.
load
(
data_path
,
[
split
],
args
.
source_lang
,
args
.
target_lang
)
# build models
ensemble
=
[]
...
...
generate.py
View file @
8f058ea0
...
...
@@ -38,7 +38,7 @@ def main():
# Load model and dataset
print
(
'| loading model(s) from {}'
.
format
(
', '
.
join
(
args
.
path
)))
models
,
dataset
=
utils
.
load_ensemble_for_inference
(
args
.
path
,
args
.
data
)
models
,
dataset
=
utils
.
load_ensemble_for_inference
(
args
.
path
,
args
.
data
,
args
.
gen_subset
)
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
src
,
len
(
dataset
.
src_dict
)))
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
dst
,
len
(
dataset
.
dst_dict
)))
...
...
@@ -81,6 +81,8 @@ def main():
bpe_symbol
=
'@@ '
if
args
.
remove_bpe
else
None
def
display_hypotheses
(
id
,
src
,
orig
,
ref
,
hypos
):
if
args
.
quiet
:
return
id_str
=
''
if
id
is
None
else
'-{}'
.
format
(
id
)
src_str
=
to_sentence
(
dataset
.
src_dict
,
src
,
bpe_symbol
)
print
(
'S{}
\t
{}'
.
format
(
id_str
,
src_str
))
...
...
train.py
View file @
8f058ea0
...
...
@@ -29,9 +29,6 @@ def main():
dataset_args
.
add_argument
(
'--valid-subset'
,
default
=
'valid'
,
metavar
=
'SPLIT'
,
help
=
'comma separated list of data subsets '
' to use for validation (train, valid, valid1,test, test1)'
)
dataset_args
.
add_argument
(
'--test-subset'
,
default
=
'test'
,
metavar
=
'SPLIT'
,
help
=
'comma separated list of data subset '
'to use for testing (train, valid, test)'
)
options
.
add_optimization_args
(
parser
)
options
.
add_checkpoint_args
(
parser
)
options
.
add_model_args
(
parser
)
...
...
@@ -48,7 +45,7 @@ def main():
torch
.
manual_seed
(
args
.
seed
)
# Load dataset
dataset
=
data
.
load_with_check
(
args
.
data
,
args
.
source_lang
,
args
.
target_lang
)
dataset
=
data
.
load_with_check
(
args
.
data
,
[
'train'
,
'valid'
],
args
.
source_lang
,
args
.
target_lang
)
if
args
.
source_lang
is
None
or
args
.
target_lang
is
None
:
# record inferred languages in args, so that it's saved in checkpoints
args
.
source_lang
,
args
.
target_lang
=
dataset
.
src
,
dataset
.
dst
...
...
@@ -100,13 +97,6 @@ def main():
train_meter
.
stop
()
print
(
'| done training in {:.1f} seconds'
.
format
(
train_meter
.
sum
))
# Generate on test set and compute BLEU score
for
beam
in
[
1
,
5
,
10
,
20
]:
for
subset
in
args
.
test_subset
.
split
(
','
):
scorer
=
score_test
(
args
,
trainer
.
get_model
(),
dataset
,
subset
,
beam
,
cuda_device
=
(
0
if
num_gpus
>
0
else
None
))
print
(
'| Test on {} with beam={}: {}'
.
format
(
subset
,
beam
,
scorer
.
result_string
()))
# Stop multiprocessing
trainer
.
stop
()
...
...
@@ -192,19 +182,5 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
return
val_loss
def
score_test
(
args
,
model
,
dataset
,
subset
,
beam
,
cuda_device
):
"""Evaluate the model on the test set and return the BLEU scorer."""
translator
=
SequenceGenerator
([
model
],
dataset
.
dst_dict
,
beam_size
=
beam
)
if
torch
.
cuda
.
is_available
():
translator
.
cuda
()
scorer
=
bleu
.
Scorer
(
dataset
.
dst_dict
.
pad
(),
dataset
.
dst_dict
.
eos
(),
dataset
.
dst_dict
.
unk
())
itr
=
dataset
.
dataloader
(
subset
,
batch_size
=
4
,
max_positions
=
args
.
max_positions
)
for
_
,
_
,
ref
,
hypos
in
translator
.
generate_batched_itr
(
itr
,
cuda_device
=
cuda_device
):
scorer
.
add
(
ref
.
int
().
cpu
(),
hypos
[
0
][
'tokens'
].
int
().
cpu
())
return
scorer
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment