Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
8f058ea0
"...text-generation-inference.git" did not exist on "e36dfaa8de2d9a9fa67eeed5ce64fd5949916c99"
Commit
8f058ea0
authored
Oct 11, 2017
by
Sergey Edunov
Browse files
Don't generate during training, add --quiet to generate.py
parent
a8260d52
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
13 additions
and
33 deletions
+13
-33
fairseq/data.py
fairseq/data.py
+5
-5
fairseq/options.py
fairseq/options.py
+2
-0
fairseq/utils.py
fairseq/utils.py
+2
-2
generate.py
generate.py
+3
-1
train.py
train.py
+1
-25
No files found.
fairseq/data.py
View file @
8f058ea0
...
@@ -17,7 +17,7 @@ from fairseq.dictionary import Dictionary
...
@@ -17,7 +17,7 @@ from fairseq.dictionary import Dictionary
from
fairseq.indexed_dataset
import
IndexedDataset
,
IndexedInMemoryDataset
from
fairseq.indexed_dataset
import
IndexedDataset
,
IndexedInMemoryDataset
def
load_with_check
(
path
,
src
=
None
,
dst
=
None
):
def
load_with_check
(
path
,
load_splits
,
src
=
None
,
dst
=
None
):
"""Loads the train, valid, and test sets from the specified folder
"""Loads the train, valid, and test sets from the specified folder
and check that training files exist."""
and check that training files exist."""
...
@@ -43,12 +43,12 @@ def load_with_check(path, src=None, dst=None):
...
@@ -43,12 +43,12 @@ def load_with_check(path, src=None, dst=None):
else
:
else
:
raise
ValueError
(
'training file not found for {}-{}'
.
format
(
src
,
dst
))
raise
ValueError
(
'training file not found for {}-{}'
.
format
(
src
,
dst
))
dataset
=
load
(
path
,
src
,
dst
)
dataset
=
load
(
path
,
load_splits
,
src
,
dst
)
return
dataset
return
dataset
def
load
(
path
,
src
,
dst
):
def
load
(
path
,
load_splits
,
src
,
dst
):
"""Loads
the train, valid, and test sets from the specified folder
."""
"""Loads
specified data splits (e.g. test, train or valid) from the path
."""
langcode
=
'{}-{}'
.
format
(
src
,
dst
)
langcode
=
'{}-{}'
.
format
(
src
,
dst
)
...
@@ -59,7 +59,7 @@ def load(path, src, dst):
...
@@ -59,7 +59,7 @@ def load(path, src, dst):
dst_dict
=
Dictionary
.
load
(
fmt_path
(
'dict.{}.txt'
,
dst
))
dst_dict
=
Dictionary
.
load
(
fmt_path
(
'dict.{}.txt'
,
dst
))
dataset
=
LanguageDatasets
(
src
,
dst
,
src_dict
,
dst_dict
)
dataset
=
LanguageDatasets
(
src
,
dst
,
src_dict
,
dst_dict
)
for
split
in
[
'train'
,
'valid'
,
'test'
]
:
for
split
in
load_splits
:
for
k
in
itertools
.
count
():
for
k
in
itertools
.
count
():
prefix
=
"{}{}"
.
format
(
split
,
k
if
k
>
0
else
''
)
prefix
=
"{}{}"
.
format
(
split
,
k
if
k
>
0
else
''
)
src_path
=
fmt_path
(
'{}.{}.{}'
,
prefix
,
langcode
,
src
)
src_path
=
fmt_path
(
'{}.{}.{}'
,
prefix
,
langcode
,
src
)
...
...
fairseq/options.py
View file @
8f058ea0
...
@@ -104,6 +104,8 @@ def add_generation_args(parser):
...
@@ -104,6 +104,8 @@ def add_generation_args(parser):
help
=
'length penalty: <1.0 favors shorter, >1.0 favors longer sentences'
)
help
=
'length penalty: <1.0 favors shorter, >1.0 favors longer sentences'
)
group
.
add_argument
(
'--unk-replace-dict'
,
default
=
''
,
type
=
str
,
group
.
add_argument
(
'--unk-replace-dict'
,
default
=
''
,
type
=
str
,
help
=
'performs unk word replacement'
)
help
=
'performs unk word replacement'
)
group
.
add_argument
(
'--quiet'
,
action
=
'store_true'
,
help
=
'Only print final scores'
)
return
group
return
group
...
...
fairseq/utils.py
View file @
8f058ea0
...
@@ -94,7 +94,7 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
...
@@ -94,7 +94,7 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
return
epoch
,
batch_offset
return
epoch
,
batch_offset
def
load_ensemble_for_inference
(
filenames
,
data_path
):
def
load_ensemble_for_inference
(
filenames
,
data_path
,
split
):
# load model architectures and weights
# load model architectures and weights
states
=
[]
states
=
[]
for
filename
in
filenames
:
for
filename
in
filenames
:
...
@@ -106,7 +106,7 @@ def load_ensemble_for_inference(filenames, data_path):
...
@@ -106,7 +106,7 @@ def load_ensemble_for_inference(filenames, data_path):
# load dataset
# load dataset
args
=
states
[
0
][
'args'
]
args
=
states
[
0
][
'args'
]
dataset
=
data
.
load
(
data_path
,
args
.
source_lang
,
args
.
target_lang
)
dataset
=
data
.
load
(
data_path
,
[
split
],
args
.
source_lang
,
args
.
target_lang
)
# build models
# build models
ensemble
=
[]
ensemble
=
[]
...
...
generate.py
View file @
8f058ea0
...
@@ -38,7 +38,7 @@ def main():
...
@@ -38,7 +38,7 @@ def main():
# Load model and dataset
# Load model and dataset
print
(
'| loading model(s) from {}'
.
format
(
', '
.
join
(
args
.
path
)))
print
(
'| loading model(s) from {}'
.
format
(
', '
.
join
(
args
.
path
)))
models
,
dataset
=
utils
.
load_ensemble_for_inference
(
args
.
path
,
args
.
data
)
models
,
dataset
=
utils
.
load_ensemble_for_inference
(
args
.
path
,
args
.
data
,
args
.
gen_subset
)
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
src
,
len
(
dataset
.
src_dict
)))
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
src
,
len
(
dataset
.
src_dict
)))
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
dst
,
len
(
dataset
.
dst_dict
)))
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
dst
,
len
(
dataset
.
dst_dict
)))
...
@@ -81,6 +81,8 @@ def main():
...
@@ -81,6 +81,8 @@ def main():
bpe_symbol
=
'@@ '
if
args
.
remove_bpe
else
None
bpe_symbol
=
'@@ '
if
args
.
remove_bpe
else
None
def
display_hypotheses
(
id
,
src
,
orig
,
ref
,
hypos
):
def
display_hypotheses
(
id
,
src
,
orig
,
ref
,
hypos
):
if
args
.
quiet
:
return
id_str
=
''
if
id
is
None
else
'-{}'
.
format
(
id
)
id_str
=
''
if
id
is
None
else
'-{}'
.
format
(
id
)
src_str
=
to_sentence
(
dataset
.
src_dict
,
src
,
bpe_symbol
)
src_str
=
to_sentence
(
dataset
.
src_dict
,
src
,
bpe_symbol
)
print
(
'S{}
\t
{}'
.
format
(
id_str
,
src_str
))
print
(
'S{}
\t
{}'
.
format
(
id_str
,
src_str
))
...
...
train.py
View file @
8f058ea0
...
@@ -29,9 +29,6 @@ def main():
...
@@ -29,9 +29,6 @@ def main():
dataset_args
.
add_argument
(
'--valid-subset'
,
default
=
'valid'
,
metavar
=
'SPLIT'
,
dataset_args
.
add_argument
(
'--valid-subset'
,
default
=
'valid'
,
metavar
=
'SPLIT'
,
help
=
'comma separated list ofdata subsets '
help
=
'comma separated list ofdata subsets '
' to use for validation (train, valid, valid1,test, test1)'
)
' to use for validation (train, valid, valid1,test, test1)'
)
dataset_args
.
add_argument
(
'--test-subset'
,
default
=
'test'
,
metavar
=
'SPLIT'
,
help
=
'comma separated list ofdata subset '
'to use for testing (train, valid, test)'
)
options
.
add_optimization_args
(
parser
)
options
.
add_optimization_args
(
parser
)
options
.
add_checkpoint_args
(
parser
)
options
.
add_checkpoint_args
(
parser
)
options
.
add_model_args
(
parser
)
options
.
add_model_args
(
parser
)
...
@@ -48,7 +45,7 @@ def main():
...
@@ -48,7 +45,7 @@ def main():
torch
.
manual_seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
)
# Load dataset
# Load dataset
dataset
=
data
.
load_with_check
(
args
.
data
,
args
.
source_lang
,
args
.
target_lang
)
dataset
=
data
.
load_with_check
(
args
.
data
,
[
'train'
,
'valid'
],
args
.
source_lang
,
args
.
target_lang
)
if
args
.
source_lang
is
None
or
args
.
target_lang
is
None
:
if
args
.
source_lang
is
None
or
args
.
target_lang
is
None
:
# record inferred languages in args, so that it's saved in checkpoints
# record inferred languages in args, so that it's saved in checkpoints
args
.
source_lang
,
args
.
target_lang
=
dataset
.
src
,
dataset
.
dst
args
.
source_lang
,
args
.
target_lang
=
dataset
.
src
,
dataset
.
dst
...
@@ -100,13 +97,6 @@ def main():
...
@@ -100,13 +97,6 @@ def main():
train_meter
.
stop
()
train_meter
.
stop
()
print
(
'| done training in {:.1f} seconds'
.
format
(
train_meter
.
sum
))
print
(
'| done training in {:.1f} seconds'
.
format
(
train_meter
.
sum
))
# Generate on test set and compute BLEU score
for
beam
in
[
1
,
5
,
10
,
20
]:
for
subset
in
args
.
test_subset
.
split
(
','
):
scorer
=
score_test
(
args
,
trainer
.
get_model
(),
dataset
,
subset
,
beam
,
cuda_device
=
(
0
if
num_gpus
>
0
else
None
))
print
(
'| Test on {} with beam={}: {}'
.
format
(
subset
,
beam
,
scorer
.
result_string
()))
# Stop multiprocessing
# Stop multiprocessing
trainer
.
stop
()
trainer
.
stop
()
...
@@ -192,19 +182,5 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
...
@@ -192,19 +182,5 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
return
val_loss
return
val_loss
def
score_test
(
args
,
model
,
dataset
,
subset
,
beam
,
cuda_device
):
"""Evaluate the model on the test set and return the BLEU scorer."""
translator
=
SequenceGenerator
([
model
],
dataset
.
dst_dict
,
beam_size
=
beam
)
if
torch
.
cuda
.
is_available
():
translator
.
cuda
()
scorer
=
bleu
.
Scorer
(
dataset
.
dst_dict
.
pad
(),
dataset
.
dst_dict
.
eos
(),
dataset
.
dst_dict
.
unk
())
itr
=
dataset
.
dataloader
(
subset
,
batch_size
=
4
,
max_positions
=
args
.
max_positions
)
for
_
,
_
,
ref
,
hypos
in
translator
.
generate_batched_itr
(
itr
,
cuda_device
=
cuda_device
):
scorer
.
add
(
ref
.
int
().
cpu
(),
hypos
[
0
][
'tokens'
].
int
().
cpu
())
return
scorer
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment