OpenDAS / Fairseq

Commit 8f9dd964, authored Oct 31, 2017 by Myle Ott
Improvements to data loader
Parent: 97d7fcb9
Showing 4 changed files with 95 additions and 99 deletions:

    fairseq/data.py    +78  -79
    generate.py         +4   -4
    interactive.py      +8  -10
    train.py            +5   -6
fairseq/data.py
@@ -18,61 +18,65 @@ from fairseq.dictionary import Dictionary
 from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
 
 
-def load_with_check(path, load_splits, src=None, dst=None):
-    """Loads specified data splits (e.g., test, train or valid) from the
-    specified folder and check that files exist."""
-
-    def find_language_pair(files):
-        for split in load_splits:
-            for filename in files:
-                parts = filename.split('.')
-                if parts[0] == split and parts[-1] == 'idx':
-                    return parts[1].split('-')
-
-    def split_exists(split, src, dst):
-        filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
-        return os.path.exists(os.path.join(path, filename))
-
+def infer_language_pair(path, splits):
+    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
+    src, dst = None, None
+    for filename in os.listdir(path):
+        parts = filename.split('.')
+        for split in splits:
+            if parts[0] == split and parts[-1] == 'idx':
+                src, dst = parts[1].split('-')
+                break
+    return src, dst
+
+
+def load_dictionaries(path, src_lang, dst_lang):
+    """Load dictionaries for a given language pair."""
+    src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
+    dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
+    return src_dict, dst_dict
+
+
+def load_dataset(path, load_splits, src=None, dst=None):
+    """Loads specified data splits (e.g., test, train or valid) from the
+    specified folder and check that files exist."""
     if src is None and dst is None:
         # find language pair automatically
-        src, dst = find_language_pair(os.listdir(path))
-    if not split_exists(load_splits[0], src, dst):
-        # try reversing src and dst
-        src, dst = dst, src
-    for split in load_splits:
-        if not split_exists(load_splits[0], src, dst):
-            raise ValueError('Data split not found: {}-{} ({})'.format(
-                src, dst, split))
-
-    dataset = load(path, load_splits, src, dst)
-    return dataset
-
-
-def load(path, load_splits, src, dst):
-    """Loads specified data splits (e.g. test, train or valid) from the path."""
-    langcode = '{}-{}'.format(src, dst)
+        src, dst = infer_language_pair(path, load_splits)
+
+    def all_splits_exist(src, dst):
+        for split in load_splits:
+            filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
+            if not os.path.exists(os.path.join(path, filename)):
+                return False
+        return True
+
+    # infer langcode
+    if all_splits_exist(src, dst):
+        langcode = '{}-{}'.format(src, dst)
+    elif all_splits_exist(dst, src):
+        langcode = '{}-{}'.format(dst, src)
+    else:
+        raise Exception('Dataset cannot be loaded from path: ' + path)
+
+    src_dict, dst_dict = load_dictionaries(path, src, dst)
+    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
 
     def fmt_path(fmt, *args):
         return os.path.join(path, fmt.format(*args))
 
-    src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
-    dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
-    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-
     for split in load_splits:
         for k in itertools.count():
             prefix = "{}{}".format(split, k if k > 0 else '')
             src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
+            dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
 
             if not IndexedInMemoryDataset.exists(src_path):
                 break
 
             dataset.splits[prefix] = LanguagePairDataset(
                 IndexedInMemoryDataset(src_path),
-                IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
+                IndexedInMemoryDataset(dst_path),
                 pad_idx=dataset.src_dict.pad(),
                 eos_idx=dataset.src_dict.eos(),
             )
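Dataset loading is now split into three entry points: infer_language_pair() picks the language pair up from the preprocessed filenames, load_dictionaries() reads dict.<lang>.txt, and load_dataset() wires everything into a LanguageDatasets object. A minimal, self-contained sketch of the filename convention these helpers rely on follows; it re-implements the inference logic from the hunk above on a throwaway directory, and the file names are made up for illustration.

```python
# Mirrors infer_language_pair() from the hunk above; the directory layout is
# a made-up example of a preprocessed <split>.<src>-<dst>.<lang>.idx dataset.
import os
import tempfile


def infer_language_pair(path, splits):
    src, dst = None, None
    for filename in os.listdir(path):
        parts = filename.split('.')
        for split in splits:
            if parts[0] == split and parts[-1] == 'idx':
                src, dst = parts[1].split('-')
                break
    return src, dst


with tempfile.TemporaryDirectory() as data_dir:
    # create an empty placeholder file for each expected dataset artifact
    for name in ['train.de-en.de.idx', 'train.de-en.en.idx',
                 'valid.de-en.de.idx', 'valid.de-en.en.idx',
                 'dict.de.txt', 'dict.en.txt']:
        open(os.path.join(data_dir, name), 'w').close()

    print(infer_language_pair(data_dir, ['train', 'valid']))  # -> ('de', 'en')
```

With such a directory in place, callers can write dataset = data.load_dataset(path, ['train', 'valid']) and let the pair be inferred, as the updated call sites in generate.py and train.py below show.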
@@ -92,13 +96,11 @@ class LanguageDatasets(object):
         assert self.src_dict.eos() == self.dst_dict.eos()
         assert self.src_dict.unk() == self.dst_dict.unk()
 
-    def dataloader(self, split, batch_size=1, num_workers=0,
-                   max_tokens=None, seed=None, epoch=1,
-                   max_positions=(1024, 1024),
-                   sample_without_replacement=0,
-                   skip_invalid_size_inputs_valid_test=False,
-                   sort_by_source_size=False):
+    def train_dataloader(self, split, num_workers=0,
+                         max_tokens=None, seed=None, epoch=1,
+                         max_positions=(1024, 1024),
+                         sample_without_replacement=0,
+                         sort_by_source_size=False):
         dataset = self.splits[split]
-        if split.startswith('train'):
-            with numpy_seed(seed):
-                batch_sampler = shuffled_batches_by_size(
-                    dataset.src, dataset.dst,
+        with numpy_seed(seed):
+            batch_sampler = shuffled_batches_by_size(
+                dataset.src, dataset.dst,
@@ -106,22 +108,23 @@ class LanguageDatasets(object):
                 sample=sample_without_replacement,
                 max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
-        elif split.startswith('valid'):
-            batch_sampler = list(batches_by_size(
-                dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                max_positions=max_positions,
-                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
-        else:
-            batch_sampler = list(batches_by_size(
-                dataset.src, batch_size, max_tokens,
-                max_positions=max_positions,
-                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
-        return torch.utils.data.DataLoader(
-            dataset,
-            num_workers=num_workers, collate_fn=dataset.collater,
-            batch_sampler=batch_sampler)
+        return torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers,
+            collate_fn=dataset.collater,
+            batch_sampler=batch_sampler)
+
+    def eval_dataloader(self, split, num_workers=0, batch_size=1,
+                        max_tokens=None, consider_dst_sizes=True,
+                        max_positions=(1024, 1024),
+                        skip_invalid_size_inputs_valid_test=False):
+        dataset = self.splits[split]
+        dst_dataset = dataset.dst if consider_dst_sizes else None
+        batch_sampler = list(batches_by_size(
+            dataset.src, dataset.dst, batch_size, max_tokens,
+            max_positions=max_positions,
+            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+        return torch.utils.data.DataLoader(
+            dataset,
+            num_workers=num_workers,
+            collate_fn=dataset.collater,
+            batch_sampler=batch_sampler,
+        )
 
 
 def skip_group_enumerator(it, ngpus, offset=0):
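The old dataloader() chose its sampling strategy by inspecting the split name; that branch is gone and callers now pick the entry point explicitly. Below is a hedged usage sketch of the two new methods, assuming a preprocessed data directory (the path and hyperparameter values are illustrative, not taken from the commit).

```python
# Assumes a preprocessed bilingual dataset at a hypothetical path.
from fairseq import data

dataset = data.load_dataset('data-bin/iwslt14.de-en', ['train', 'valid'])

train_itr = dataset.train_dataloader(
    'train', num_workers=4, max_tokens=6000,
    seed=1, epoch=1,
    max_positions=(1024, 1024),
    sample_without_replacement=0,
    sort_by_source_size=False)

valid_itr = dataset.eval_dataloader(
    'valid', max_tokens=6000,
    max_positions=(1024, 1024),
    skip_invalid_size_inputs_valid_test=True)

for batch in train_itr:
    # each batch is the dict built by LanguagePairDataset.collater (see below)
    print(batch['ntokens'], batch['src_tokens'].size())
    break
```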
@@ -174,14 +177,15 @@ class LanguagePairDataset(object):
             return LanguagePairDataset.collate_tokens(
                 [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
 
+        ntokens = sum(len(s['target']) for s in samples)
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
+            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
+            # we create a shifted version of targets for feeding the previous
+            # output token(s) into the next decoder step
             'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                                   move_eos_to_beginning=True),
             'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
-            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-            'ntokens': sum(len(s['target']) for s in samples),
+            'ntokens': ntokens,
         }
 
     @staticmethod
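The new comment documents why the batch carries both target and input_tokens: the decoder is fed the target shifted one step, with the EOS token rotated to the front. A standalone illustration of that rotation follows; the token ids are made up, and this is not the collate_tokens implementation itself.

```python
import torch

eos = 2
target = torch.LongTensor([15, 7, 42, eos])            # w1 w2 w3 </s>
# rotate: move the trailing </s> to the front, shift everything else right
input_tokens = torch.cat([target[-1:], target[:-1]])   # </s> w1 w2 w3

print(target.tolist())        # [15, 7, 42, 2]
print(input_tokens.tolist())  # [2, 15, 7, 42]
```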
@@ -218,18 +222,14 @@ def _valid_size(src_size, dst_size, max_positions):
     return True
 
 
-def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
+def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                     max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences of different lengths
-    are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset)
-    assert dst is None or isinstance(dst, IndexedDataset)
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
-    sizes = src.sizes
-    indices = np.argsort(sizes, kind='mergesort')
-    if dst is not None:
-        sizes = np.maximum(sizes, dst.sizes)
+    indices = np.argsort(src.sizes, kind='mergesort')
     batch = []

@@ -238,7 +238,7 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
             return False
         if len(batch) == batch_size:
             return True
-        if sizes[batch[0]] != sizes[next_idx]:
+        if src.sizes[batch[0]] != src.sizes[next_idx]:
             return True
         if num_tokens >= max_tokens:
             return True

@@ -247,21 +247,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        if not _valid_size(src.sizes[idx],
-                           None if dst is None else dst.sizes[idx], max_positions):
+        if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of size {} / {}.".format(
-                idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))
+            raise Exception(
+                "Unable to handle input id {} of size {} / {}.".format(
+                    idx, src.sizes[idx], dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
             batch = []
             cur_max_size = 0
         batch.append(idx)
-        cur_max_size = max(cur_max_size, sizes[idx])
+        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
 
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
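batches_by_size() now always receives a dst dataset, sorts indices by source length only, and cuts a batch whenever the source length changes, the sentence cap is hit, or the token budget (counted over the longer side of each pair) would overflow. Below is a simplified, self-contained sketch of that rule on made-up size arrays; it is not the exact generator above, just an illustration of the batching policy.

```python
import numpy as np

src_sizes = np.array([5, 3, 5, 3, 7, 5, 3])
dst_sizes = np.array([6, 2, 4, 3, 9, 5, 4])


def toy_batches_by_size(src_sizes, dst_sizes, batch_size=2, max_tokens=20):
    # sort by source length (stable, like the mergesort argsort above)
    indices = np.argsort(src_sizes, kind='mergesort')
    batch, cur_max = [], 0
    for idx in indices:
        size = max(src_sizes[idx], dst_sizes[idx])
        # cut the batch on: sentence cap, source-length change, or token budget
        if batch and (len(batch) == batch_size
                      or src_sizes[batch[0]] != src_sizes[idx]
                      or max(cur_max, size) * (len(batch) + 1) > max_tokens):
            yield batch
            batch, cur_max = [], 0
        batch.append(int(idx))
        cur_max = max(cur_max, size)
    if batch:
        yield batch


print(list(toy_batches_by_size(src_sizes, dst_sizes)))
```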
generate.py
@@ -34,7 +34,7 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu
 
     # Load dataset
-    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args
         args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -67,8 +67,8 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
-                             max_positions=max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    itr = dataset.eval_dataloader(
+        args.gen_subset, batch_size=args.batch_size,
+        max_positions=max_positions,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
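Generation keeps the same overall pattern but now goes through eval_dataloader(). A hedged sketch of the call shape, assuming `models` is a loaded ensemble and `dataset` comes from data.load_dataset(...) as in generate.py; the subset name, batch size and counting loop are illustrative.

```python
# assumes `models`, `dataset` and `progress_bar` are set up as in generate.py
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader(
    'test', batch_size=32,
    max_positions=max_positions,
    skip_invalid_size_inputs_valid_test=True)

num_sentences = 0
with progress_bar(itr, smoothing=0, leave=False) as t:
    for sample in t:
        num_sentences += sample['id'].numel()  # schematic; real loop also decodes
```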
interactive.py
@@ -26,19 +26,17 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu
 
-    # Load dataset
-    # TODO: load only dictionaries
-    dataset = data.load_with_check(args.data, ['test'], args.source_lang, args.target_lang)
+    # Load dictionaries
     if args.source_lang is None or args.target_lang is None:
-        # record inferred languages in args
-        args.source_lang, args.target_lang = dataset.src, dataset.dst
+        args.source_lang, args.target_lang, _ = data.infer_language_pair(args.data, ['test'])
+    src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
 
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
-    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
-    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
+    models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
+    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
+    print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
 
     # Optimize ensemble for generation
     for model in models:
@@ -60,7 +58,7 @@ def main():
     print('Type the input sentence and press return:')
     for src_str in sys.stdin:
         src_str = src_str.strip()
-        src_tokens = tokenizer.Tokenizer.tokenize(src_str, dataset.src_dict, add_if_not_exist=False).long()
+        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
         if use_cuda:
             src_tokens = src_tokens.cuda()
         translations = translator.generate(Variable(src_tokens.view(1, -1)))
@@ -74,7 +72,7 @@ def main():
                 src_str=src_str,
                 alignment=hypo['alignment'].int().cpu(),
                 align_dict=align_dict,
-                dst_dict=dataset.dst_dict,
+                dst_dict=dst_dict,
                 remove_bpe=args.remove_bpe)
             print('A\t{}'.format(' '.join(map(str, alignment))))
             print('H\t{}\t{}'.format(hypo['score'], hypo_str))
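interactive.py no longer materialises any dataset; it only needs the two dictionaries. A hedged sketch of that path follows, with a made-up data directory, language pair and input sentence; the dictionary and tokenizer calls mirror the hunks above.

```python
from fairseq import data, tokenizer

data_dir = 'data-bin/iwslt14.de-en'     # hypothetical preprocessed directory
src_lang, dst_lang = 'de', 'en'         # normally inferred via data.infer_language_pair(...)

# only the dictionaries are loaded; no index files are touched
src_dict, dst_dict = data.load_dictionaries(data_dir, src_lang, dst_lang)

src_tokens = tokenizer.Tokenizer.tokenize(
    'hallo welt .', src_dict, add_if_not_exist=False).long()
print(len(src_dict), len(dst_dict), src_tokens.tolist())
```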
train.py
@@ -34,7 +34,6 @@ def main():
     options.add_model_args(parser)
 
     args = utils.parse_args_and_arch(parser)
-    print(args)
 
     if args.no_progress_bar:
         progress_bar.enabled = False
@@ -45,11 +44,12 @@ def main():
     torch.manual_seed(args.seed)
 
     # Load dataset
-    dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
+    dataset = data.load_dataset(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args, so that it's saved in checkpoints
         args.source_lang, args.target_lang = dataset.src, dataset.dst
 
+    print(args)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
     for split in ['train', 'valid']:
@@ -129,11 +129,10 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     torch.manual_seed(seed)
     trainer.set_seed(seed)
 
-    itr = dataset.dataloader(
-        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
-        seed=seed, epoch=epoch, max_positions=max_positions,
-        sample_without_replacement=args.sample_without_replacement,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-        sort_by_source_size=(epoch <= args.curriculum))
+    itr = dataset.train_dataloader(
+        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        max_positions=max_positions, seed=seed, epoch=epoch,
+        sample_without_replacement=args.sample_without_replacement,
+        sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
@@ -216,7 +215,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
 def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
 
-    itr = dataset.dataloader(
+    itr = dataset.eval_dataloader(
         subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
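train.py now reshuffles every epoch by passing seed and epoch into train_dataloader(), and the curriculum option simply flips sort_by_source_size for the first epochs; validation goes through eval_dataloader(). A schematic sketch of that loop, assuming args, dataset and max_positions are set up as in the file above; only the call shapes are taken from the diff, the epoch range and loop bodies are placeholders.

```python
for epoch in range(1, 11):                    # e.g. train for 10 epochs
    train_itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
        max_positions=max_positions, seed=args.seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    # ... forward/backward over train_itr ...

    valid_itr = dataset.eval_dataloader(
        'valid', max_tokens=args.max_tokens, max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    # ... compute validation loss over valid_itr ...
```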