OpenDAS / Fairseq / Commits / 8f9dd964

Commit 8f9dd964, authored Oct 31, 2017 by Myle Ott

Improvements to data loader

Parent: 97d7fcb9
Showing 4 changed files with 95 additions and 99 deletions.

fairseq/data.py   +78 -79
generate.py        +4 -4
interactive.py     +8 -10
train.py           +5 -6

fairseq/data.py

@@ -18,61 +18,65 @@ from fairseq.dictionary import Dictionary
 from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
 
 
-def load_with_check(path, load_splits, src=None, dst=None):
-    """Loads specified data splits (e.g., test, train or valid) from the
-    specified folder and check that files exist."""
-
-    def find_language_pair(files):
-        for split in load_splits:
-            for filename in files:
-                parts = filename.split('.')
-                if parts[0] == split and parts[-1] == 'idx':
-                    return parts[1].split('-')
-
-    def split_exists(split, src, dst):
-        filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
-        return os.path.exists(os.path.join(path, filename))
+def infer_language_pair(path, splits):
+    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
+    src, dst = None, None
+    for filename in os.listdir(path):
+        parts = filename.split('.')
+        for split in splits:
+            if parts[0] == split and parts[-1] == 'idx':
+                src, dst = parts[1].split('-')
+                break
+    return src, dst
+
+
+def load_dictionaries(path, src_lang, dst_lang):
+    """Load dictionaries for a given language pair."""
+    src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
+    dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
+    return src_dict, dst_dict
+
+
+def load_dataset(path, load_splits, src=None, dst=None):
+    """Loads specified data splits (e.g., test, train or valid) from the
+    specified folder and check that files exist."""
 
     if src is None and dst is None:
         # find language pair automatically
-        src, dst = find_language_pair(os.listdir(path))
-        if not split_exists(load_splits[0], src, dst):
-            # try reversing src and dst
-            src, dst = dst, src
-    for split in load_splits:
-        if not split_exists(load_splits[0], src, dst):
-            raise ValueError('Data split not found: {}-{} ({})'.format(src, dst, split))
-
-    dataset = load(path, load_splits, src, dst)
-    return dataset
-
-
-def load(path, load_splits, src, dst):
-    """Loads specified data splits (e.g. test, train or valid) from the path."""
+        src, dst = infer_language_pair(path, load_splits)
+
+    def all_splits_exist(src, dst):
+        for split in load_splits:
+            filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
+            if not os.path.exists(os.path.join(path, filename)):
+                return False
+        return True
+
+    # infer langcode
+    if all_splits_exist(src, dst):
+        langcode = '{}-{}'.format(src, dst)
+    elif all_splits_exist(dst, src):
+        langcode = '{}-{}'.format(dst, src)
+    else:
+        raise Exception('Dataset cannot be loaded from path: ' + path)
+
+    src_dict, dst_dict = load_dictionaries(path, src, dst)
+    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
 
     def fmt_path(fmt, *args):
         return os.path.join(path, fmt.format(*args))
 
-    src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
-    dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
-    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-
     for split in load_splits:
         for k in itertools.count():
            prefix = "{}{}".format(split, k if k > 0 else '')
            src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
+            dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
            if not IndexedInMemoryDataset.exists(src_path):
                break
            dataset.splits[prefix] = LanguagePairDataset(
                IndexedInMemoryDataset(src_path),
-                IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
+                IndexedInMemoryDataset(dst_path),
                pad_idx=dataset.src_dict.pad(),
                eos_idx=dataset.src_dict.eos(),
            )
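For orientation, here is a rough sketch of how the reworked loader is meant to be called after this commit. The data directory and split names are illustrative, and the `from fairseq import data` import is assumed rather than shown in this diff:

    from fairseq import data

    # Hypothetical binarized data directory; when src/dst are omitted the pair is
    # inferred from the '<split>.<lang1>-<lang2>.*.idx' filenames.
    dataset = data.load_dataset('data-bin/example', ['train', 'valid'])
    print(dataset.src, dataset.dst)        # inferred language pair
    print(sorted(dataset.splits.keys()))   # 'train', 'valid', plus any 'train1', 'train2', ... shards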
@@ -92,13 +96,11 @@ class LanguageDatasets(object):
         assert self.src_dict.eos() == self.dst_dict.eos()
         assert self.src_dict.unk() == self.dst_dict.unk()
 
-    def dataloader(self, split, batch_size=1, num_workers=0,
-                   max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=(1024, 1024),
-                   skip_invalid_size_inputs_valid_test=False,
-                   sort_by_source_size=False):
+    def train_dataloader(self, split, num_workers=0, max_tokens=None,
+                         max_positions=(1024, 1024), seed=None, epoch=1,
+                         sample_without_replacement=0,
+                         sort_by_source_size=False):
         dataset = self.splits[split]
-        if split.startswith('train'):
-            with numpy_seed(seed):
-                batch_sampler = shuffled_batches_by_size(
-                    dataset.src, dataset.dst,
+        with numpy_seed(seed):
+            batch_sampler = shuffled_batches_by_size(
+                dataset.src, dataset.dst,
@@ -106,22 +108,23 @@ class LanguageDatasets(object):
                 sample=sample_without_replacement,
                 max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
-        elif split.startswith('valid'):
-            batch_sampler = list(batches_by_size(
-                dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                max_positions=max_positions,
-                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
-        else:
-            batch_sampler = list(batches_by_size(
-                dataset.src, batch_size, max_tokens,
-                max_positions=max_positions,
-                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
-        return torch.utils.data.DataLoader(
-            dataset,
-            num_workers=num_workers,
-            collate_fn=dataset.collater,
-            batch_sampler=batch_sampler,
-        )
+        return torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers,
+            collate_fn=dataset.collater, batch_sampler=batch_sampler)
+
+    def eval_dataloader(self, split, num_workers=0, batch_size=1,
+                        max_tokens=None, consider_dst_sizes=True,
+                        max_positions=(1024, 1024),
+                        skip_invalid_size_inputs_valid_test=False):
+        dataset = self.splits[split]
+        dst_dataset = dataset.dst if consider_dst_sizes else None
+        batch_sampler = list(batches_by_size(
+            dataset.src, dataset.dst, batch_size, max_tokens,
+            max_positions=max_positions,
+            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+        return torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers,
+            collate_fn=dataset.collater, batch_sampler=batch_sampler)
 
 
 def skip_group_enumerator(it, ngpus, offset=0):
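The old `dataloader()` entry point is now split in two: `train_dataloader()` shuffles batches under a fixed seed via `shuffled_batches_by_size`, while `eval_dataloader()` builds a deterministic `batches_by_size` sampler. A hedged usage sketch for the eval side (path, split name, and batch size are illustrative):

    from fairseq import data

    dataset = data.load_dataset('data-bin/example', ['valid'])  # hypothetical path
    itr = dataset.eval_dataloader(
        'valid', batch_size=32, max_positions=(1024, 1024),
        skip_invalid_size_inputs_valid_test=True)
    for batch in itr:
        pass  # each batch is a dict built by LanguagePairDataset.collater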
@@ -174,14 +177,15 @@ class LanguagePairDataset(object):
             return LanguagePairDataset.collate_tokens([s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
 
+        ntokens = sum(len(s['target']) for s in samples)
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
+            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
             # we create a shifted version of targets for feeding the previous
             # output token(s) into the next decoder step
             'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                                   move_eos_to_beginning=True),
             'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
-            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-            'ntokens': sum(len(s['target']) for s in samples),
+            'ntokens': ntokens,
         }
 
     @staticmethod
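The collater change only hoists the `ntokens` sum into a local and reorders the dictionary keys; the batch layout itself is unchanged. As a reminder of what a consumer receives (key names taken from the hunk above, comments are interpretation):

    def describe_batch(batch):
        """Pull apart one collated batch produced by LanguagePairDataset.collater."""
        ids = batch['id']              # LongTensor of sample ids in this batch
        src = batch['src_tokens']      # padded source tokens (per LEFT_PAD_SOURCE)
        prev = batch['input_tokens']   # target shifted right: EOS moved to the front,
                                       # fed to the decoder as previous output tokens
        tgt = batch['target']          # padded target tokens
        return ids, src, prev, tgt, batch['ntokens']  # ntokens = total target tokens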
@@ -218,18 +222,14 @@ def _valid_size(src_size, dst_size, max_positions):
     return True
 
 
-def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
+def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                     max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences of different lengths
-    are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset)
-    assert dst is None or isinstance(dst, IndexedDataset)
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
-    sizes = src.sizes
-    indices = np.argsort(sizes, kind='mergesort')
-    if dst is not None:
-        sizes = np.maximum(sizes, dst.sizes)
+    indices = np.argsort(src.sizes, kind='mergesort')
 
     batch = []
@@ -238,7 +238,7 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
             return False
         if len(batch) == batch_size:
             return True
-        if sizes[batch[0]] != sizes[next_idx]:
+        if src.sizes[batch[0]] != src.sizes[next_idx]:
             return True
         if num_tokens >= max_tokens:
             return True
@@ -247,21 +247,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        if not _valid_size(src.sizes[idx],
-                           None if dst is None else dst.sizes[idx],
-                           max_positions):
+        if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of size {} / {}.".format(
-                idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))
+            raise Exception("Unable to handle input id {} of size {} / {}.".format(
+                idx, src.sizes[idx], dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
             batch = []
             cur_max_size = 0
         batch.append(idx)
-        cur_max_size = max(cur_max_size, sizes[idx])
+        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
 
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
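Restated informally, `batches_by_size` now always considers both sides of the pair: indices are sorted by source length, a batch is cut when the source length changes, the sentence cap is reached, or the token budget is exceeded, and over-long pairs are skipped or rejected. The sketch below is a simplified standalone restatement of that rule, not the fairseq implementation (the function name and example sizes are made up):

    def toy_batches_by_size(src_sizes, dst_sizes, batch_size=None, max_tokens=None,
                            max_positions=(1024, 1024)):
        """Toy restatement of the grouping rule above: walk indices sorted by
        source length and start a new batch when the source length changes, the
        sentence cap is hit, or the token budget would be exhausted."""
        if batch_size is None:
            batch_size = float('Inf')
        if max_tokens is None:
            max_tokens = float('Inf')
        indices = sorted(range(len(src_sizes)), key=lambda i: src_sizes[i])
        batch, cur_max_size = [], 0
        for idx in indices:
            # counterpart of _valid_size: skip pairs longer than the model allows
            if src_sizes[idx] > max_positions[0] or dst_sizes[idx] > max_positions[1]:
                continue
            if batch and (len(batch) == batch_size
                          or src_sizes[batch[0]] != src_sizes[idx]
                          or cur_max_size * (len(batch) + 1) >= max_tokens):
                yield batch
                batch, cur_max_size = [], 0
            batch.append(idx)
            cur_max_size = max(cur_max_size, src_sizes[idx], dst_sizes[idx])
        if batch:
            yield batch

    # Example: three short pairs and one longer pair, capped at 2 sentences per batch.
    print(list(toy_batches_by_size([3, 3, 3, 7], [4, 3, 5, 8], batch_size=2)))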

generate.py

@@ -34,7 +34,7 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu
 
     # Load dataset
-    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args
         args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -67,8 +67,8 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
+    itr = dataset.eval_dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
                              skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:

interactive.py

@@ -26,19 +26,17 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu
 
-    # Load dataset
-    # TODO: load only dictionaries
-    dataset = data.load_with_check(args.data, ['test'], args.source_lang, args.target_lang)
+    # Load dictionaries
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args
-        args.source_lang, args.target_lang = dataset.src, dataset.dst
+        args.source_lang, args.target_lang = data.infer_language_pair(args.data, ['test'])
+    src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
 
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
 
-    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
-    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
+    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
+    print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
 
     # Optimize ensemble for generation
     for model in models:
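The point of this change is that interactive translation no longer loads the binarized dataset at all; it only needs the two `dict.<lang>.txt` files. A rough sketch of the new path, with an illustrative data directory:

    from fairseq import data

    # Only the dict.<lang>.txt files are needed; the language pair can still be
    # inferred from a binarized test split if one is present.
    src_lang, dst_lang = data.infer_language_pair('data-bin/example', ['test'])
    src_dict, dst_dict = data.load_dictionaries('data-bin/example', src_lang, dst_lang)
    print('| [{}] dictionary: {} types'.format(src_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(dst_lang, len(dst_dict)))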
@@ -60,7 +58,7 @@ def main():
     print('Type the input sentence and press return:')
     for src_str in sys.stdin:
         src_str = src_str.strip()
-        src_tokens = tokenizer.Tokenizer.tokenize(src_str, dataset.src_dict, add_if_not_exist=False).long()
+        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
         if use_cuda:
             src_tokens = src_tokens.cuda()
         translations = translator.generate(Variable(src_tokens.view(1, -1)))
@@ -74,7 +72,7 @@ def main():
                 src_str=src_str,
                 alignment=hypo['alignment'].int().cpu(),
                 align_dict=align_dict,
-                dst_dict=dataset.dst_dict,
+                dst_dict=dst_dict,
                 remove_bpe=args.remove_bpe)
             print('A\t{}'.format(' '.join(map(str, alignment))))
             print('H\t{}\t{}'.format(hypo['score'], hypo_str))

train.py

@@ -34,7 +34,6 @@ def main():
     options.add_model_args(parser)
     args = utils.parse_args_and_arch(parser)
-    print(args)
 
     if args.no_progress_bar:
         progress_bar.enabled = False
@@ -45,11 +44,12 @@ def main():
     torch.manual_seed(args.seed)
 
     # Load dataset
-    dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
+    dataset = data.load_dataset(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args, so that it's saved in checkpoints
         args.source_lang, args.target_lang = dataset.src, dataset.dst
+    print(args)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
     for split in ['train', 'valid']:
@@ -129,11 +129,10 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     torch.manual_seed(seed)
     trainer.set_seed(seed)
 
-    itr = dataset.dataloader(
+    itr = dataset.train_dataloader(
         args.train_subset, num_workers=args.workers,
-        max_tokens=args.max_tokens, seed=seed, epoch=epoch,
-        max_positions=max_positions,
+        max_tokens=args.max_tokens, max_positions=max_positions,
+        seed=seed, epoch=epoch,
         sample_without_replacement=args.sample_without_replacement,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
         sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
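With the new `train_dataloader()`, the curriculum is expressed purely through `sort_by_source_size=(epoch <= args.curriculum)`: early epochs iterate short-to-long, later epochs shuffle freely. A hedged per-epoch sketch (path, worker count, and token budget are illustrative):

    from fairseq import data

    dataset = data.load_dataset('data-bin/example', ['train', 'valid'])  # hypothetical path
    for epoch in range(1, 11):
        itr = dataset.train_dataloader(
            'train', num_workers=4, max_tokens=6000,
            max_positions=(1024, 1024), seed=1, epoch=epoch,
            sample_without_replacement=0,
            sort_by_source_size=(epoch <= 1))  # curriculum: only epoch 1 is length-sorted
        for batch in itr:
            pass  # hand each batch to the trainer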
@@ -216,7 +215,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
 def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
-    itr = dataset.dataloader(
+    itr = dataset.eval_dataloader(
         subset, batch_size=None, max_tokens=args.max_tokens,
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()