OpenDAS / Fairseq

Commit 7d44181d
authored Nov 04, 2017 by Myle Ott

Loop over evaluation dataloader in descending order

parent f442f896
Showing 2 changed files with 11 additions and 4 deletions (+11 −4):

fairseq/data.py  +8 −3
train.py         +3 −1
fairseq/data.py
@@ -113,12 +113,14 @@ class LanguageDatasets(object):
     def eval_dataloader(self, split, num_workers=0, max_tokens=None,
                         max_sentences=None, max_positions=(1024, 1024),
-                        skip_invalid_size_inputs_valid_test=False):
+                        skip_invalid_size_inputs_valid_test=False,
+                        descending=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
             dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
-            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
+            descending=descending))
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
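For reference, the batch_sampler argument used here replaces DataLoader's own batching: each element of batch_sampler is a complete list of dataset indices, and collate_fn merges the corresponding examples into one batch, so whatever order batches_by_size produces is exactly the order the loader yields. A minimal sketch with toy data (not fairseq code) illustrates the contract:

import torch

data = list(range(10))
# Each inner list is one full batch of indices, e.g. already sorted
# largest-first by some notion of size.
batch_sampler = [[9, 8], [7, 6, 5], [4, 3, 2, 1, 0]]

loader = torch.utils.data.DataLoader(
    data,
    collate_fn=lambda samples: torch.tensor(samples),  # merge samples into a batch
    batch_sampler=batch_sampler,
)
for batch in loader:
    print(batch)  # tensor([9, 8]), tensor([7, 6, 5]), tensor([4, 3, 2, 1, 0])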
@@ -264,7 +266,8 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
 def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False,
+                    descending=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -273,6 +276,8 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
     if max_sentences is None:
         max_sentences = float('Inf')
     indices = np.argsort(src.sizes, kind='mergesort')
+    if descending:
+        indices = np.flip(indices, 0)
     return _make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
         ignore_invalid_inputs, allow_different_src_lens=False)
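The new descending flag builds on the existing stable sort: np.argsort(src.sizes, kind='mergesort') orders examples by source length while preserving the original order of equal-length examples, and np.flip(indices, 0) reverses that ordering. A small sketch with a made-up sizes array (a stand-in for src.sizes) shows the effect:

import numpy as np

sizes = np.array([5, 12, 7, 12, 3])  # illustrative source lengths

# Stable ascending sort by length; mergesort keeps the original order
# of the two length-12 examples (index 1 before index 3).
indices = np.argsort(sizes, kind='mergesort')
print(indices)              # [4 0 2 1 3]

# Reversed for descending=True: the longest, most memory-hungry
# examples come first.
print(np.flip(indices, 0))  # [3 1 2 0 4]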
train.py
@@ -222,7 +222,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     itr = dataset.eval_dataloader(
         subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
         max_positions=max_positions,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+        descending=True,  # largest batch first to warm the caching allocator
+    )
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
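The inline comment carries the motivation for the whole commit: PyTorch's CUDA caching allocator grows its pool to fit the largest allocation it has seen, so processing the biggest validation batch first reserves enough memory up front instead of repeatedly expanding (and potentially fragmenting) the cache as batch sizes creep upward. A hypothetical call site after this change (the split name and size limits are illustrative, not from the commit) might look like:

# Hypothetical usage sketch; 'valid' and the limits are assumptions.
itr = dataset.eval_dataloader(
    'valid',
    max_tokens=4000,
    max_positions=(1024, 1024),
    descending=True,  # first batch is the largest, warming the allocator
)
for batch in itr:
    pass  # run validation on each batch, largest first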