Fairseq · Commit 7d44181d
Authored Nov 04, 2017 by Myle Ott
Parent: f442f896

Loop over evaluation dataloader in descending order
Showing 2 changed files with 11 additions and 4 deletions:

  fairseq/data.py   +8  -3
  train.py          +3  -1
fairseq/data.py

@@ -113,12 +113,14 @@ class LanguageDatasets(object):
     def eval_dataloader(self, split, num_workers=0, max_tokens=None,
                         max_sentences=None, max_positions=(1024, 1024),
-                        skip_invalid_size_inputs_valid_test=False):
+                        skip_invalid_size_inputs_valid_test=False,
+                        descending=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
             dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
-            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
+            descending=descending))
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers,
             collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
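For context on the hunk above: the DataLoader here receives a fully precomputed batch_sampler, so batch composition (and now iteration order) is decided entirely by batches_by_size before iteration starts. A minimal sketch of that mechanism, not fairseq code; ToyDataset and the hard-coded index groups are hypothetical stand-ins:

```python
# Sketch only: a batch_sampler yields whole lists of dataset indices,
# so the DataLoader's batching is controlled entirely by the sampler.
import torch


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, sizes):
        self.sizes = sizes

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, i):
        return torch.zeros(self.sizes[i])  # stand-in for a real sample


dataset = ToyDataset([3, 1, 2, 2])
# Hypothetical precomputed batches of equal-length items, analogous to
# what batches_by_size() would produce.
batch_sampler = [[2, 3], [0], [1]]
loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=batch_sampler,
    collate_fn=torch.stack)  # safe here: lengths match within each batch
for batch in loader:
    print(batch.shape)  # [2, 2], then [1, 3], then [1, 1]
```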
@@ -264,7 +266,8 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
 def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False,
+                    descending=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)

@@ -273,6 +276,8 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
     if max_sentences is None:
         max_sentences = float('Inf')
     indices = np.argsort(src.sizes, kind='mergesort')
+    if descending:
+        indices = np.flip(indices, 0)
     return _make_batches(src, dst, indices, max_tokens, max_sentences,
                          max_positions, ignore_invalid_inputs,
                          allow_different_src_lens=False)
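The ordering logic added in the last hunk is just a stable ascending argsort over source lengths, optionally flipped for largest-first iteration. A minimal sketch with hypothetical lengths:

```python
import numpy as np

sizes = np.array([5, 2, 9, 2, 7])              # hypothetical source lengths
indices = np.argsort(sizes, kind='mergesort')  # stable: ties keep dataset order
print(indices)              # [1 3 0 4 2]  -> lengths 2, 2, 5, 7, 9
print(np.flip(indices, 0))  # [2 4 0 3 1]  -> lengths 9, 7, 5, 2, 2
```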
train.py

@@ -222,7 +222,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     itr = dataset.eval_dataloader(
         subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
         max_positions=max_positions,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+        descending=True,  # largest batch first to warm the caching allocator
+    )
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
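The inline comment states the motivation: with descending=True, validation sees its largest batch first, so PyTorch's caching allocator reserves roughly the peak memory up front and later, smaller batches can reuse the cached blocks instead of triggering fresh cudaMalloc calls mid-validation. A hedged sketch of that idea, assuming a CUDA device is available; the tensor sizes are illustrative only:

```python
import torch

if torch.cuda.is_available():
    for numel in (1 << 24, 1 << 22, 1 << 20):  # descending "batch" sizes
        x = torch.empty(numel, device='cuda')
        del x  # the freed block stays in the allocator's cache
    # Later, smaller requests can be served from memory cached by the first,
    # largest allocation, so reserved memory is claimed once up front rather
    # than growing batch by batch.
    print(torch.cuda.memory_reserved() // (1 << 20), 'MiB reserved')
```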