Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
24d7de44
Commit
24d7de44
authored
May 30, 2018
by
Myle Ott
Browse files
Unify various sharding into ShardedIterator
parent
76b5ecab
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
35 deletions
+27
-35
eval_lm.py
eval_lm.py
+1
-5
fairseq/data/data_utils.py
fairseq/data/data_utils.py
+21
-22
fairseq/data/language_dataset.py
fairseq/data/language_dataset.py
+4
-4
generate.py
generate.py
+1
-4
No files found.
eval_lm.py
View file @
24d7de44
...
...
@@ -44,11 +44,7 @@ def main(args):
max_positions
=
args
.
max_target_positions
or
1024
,
descending
=
True
,
)
if
args
.
num_shards
>
1
:
if
args
.
shard_id
<
0
or
args
.
shard_id
>=
args
.
num_shards
:
raise
ValueError
(
'--shard-id must be between 0 and num_shards'
)
itr
=
data_utils
.
sharded_iterator
(
itr
,
args
.
num_shards
,
args
.
shard_id
)
itr
=
data_utils
.
ShardedIterator
(
itr
,
args
.
num_shards
,
args
.
shard_id
)
gen_timer
=
StopwatchMeter
()
scorer
=
SequenceScorer
(
models
)
...
...
fairseq/data/data_utils.py
View file @
24d7de44
...
...
@@ -7,6 +7,7 @@
import
contextlib
import
glob
import
itertools
import
math
import
numbers
import
numpy
as
np
...
...
@@ -50,21 +51,31 @@ def fmt_path(path, fmt, *args):
return
os
.
path
.
join
(
path
,
fmt
.
format
(
*
args
))
class ShardedIterator(object):
    """A sharded wrapper around an iterable (padded to length).

    Each shard yields every ``num_shards``-th element of ``iterable``,
    starting at offset ``shard_id``.  Short shards are padded with
    ``fill_value`` so that every shard reports and yields the same
    number of elements, keeping distributed workers in lock-step.

    Args:
        iterable: a sized iterable to shard (must support ``len()``).
        num_shards (int): total number of shards.
        shard_id (int): index of this shard, in ``[0, num_shards)``.
        fill_value: padding element for shards that run out of real
            elements (default: ``None``).

    Raises:
        ValueError: if ``shard_id`` is not in ``[0, num_shards)``.
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError('shard_id must be between 0 and num_shards')

        # Common padded length: ceil(len(iterable) / num_shards).
        self._sharded_len = len(iterable) // num_shards
        if len(iterable) % num_shards > 0:
            self._sharded_len += 1

        # zip_longest pairs each slot index with the strided slice of
        # this shard's elements, padding missing ones with fill_value.
        self.itr = itertools.zip_longest(
            range(self._sharded_len),
            itertools.islice(iterable, shard_id, len(iterable), num_shards),
            fillvalue=fill_value,
        )

    def __len__(self):
        # Padded length, identical for every shard.
        return self._sharded_len

    def __iter__(self):
        return self

    def __next__(self):
        # Drop the slot index from zip_longest; return only the payload.
        return next(self.itr)[1]
def
collate_tokens
(
values
,
pad_idx
,
eos_idx
,
left_pad
,
move_eos_to_beginning
=
False
):
...
...
@@ -195,18 +206,6 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
return
batches
def mask_batches(batch_sampler, shard_id, num_shards):
    """Keep only this shard's batches, padded to a uniform length.

    Selects every ``num_shards``-th batch starting at ``shard_id`` and
    appends empty batches so that all shards end up with
    ``ceil(len(batch_sampler) / num_shards)`` entries.  With a single
    shard the sampler is returned untouched.
    """
    if num_shards == 1:
        return batch_sampler

    # Collect the batches belonging to this shard.
    kept = []
    for idx, batch in enumerate(batch_sampler):
        if idx % num_shards == shard_id:
            kept.append(batch)

    # Pad with empty batches so every shard has the same length.
    target_len = int(math.ceil(len(batch_sampler) / num_shards))
    while len(kept) < target_len:
        kept.append([])
    return kept
@
contextlib
.
contextmanager
def
numpy_seed
(
seed
):
"""Context manager which seeds the NumPy PRNG with the specified seed and
...
...
fairseq/data/language_dataset.py
View file @
24d7de44
...
...
@@ -10,7 +10,7 @@ import itertools
import
numpy
as
np
import
torch
from
fairseq.data.data_utils
import
numpy_seed
,
uneven_batches_by_size
,
mask_batches
,
batches_by_size
from
fairseq.data.data_utils
import
numpy_seed
,
uneven_batches_by_size
,
ShardedIterator
,
batches_by_size
class
LanguageDatasets
(
object
):
...
...
@@ -41,7 +41,7 @@ class LanguageDatasets(object):
frozen_batches
=
tuple
(
batches
)
# freeze
def
dataloader
(
b
):
b
=
mask_batches
(
b
,
shard_id
=
shard_id
,
num_shards
=
num_shards
)
# shard dataset
b
=
ShardedIterator
(
b
,
num_shards
,
shard_id
,
fill_value
=
[])
return
torch
.
utils
.
data
.
DataLoader
(
dataset
,
collate_fn
=
dataset
.
collater
,
batch_sampler
=
b
)
for
epoch
in
itertools
.
count
(
1
):
...
...
@@ -74,7 +74,7 @@ class LanguageDatasets(object):
ignore_invalid_inputs
=
skip_invalid_size_inputs_valid_test
,
descending
=
descending
,
allow_different_src_lens
=
True
)
batch_sampler
=
mask_batches
(
batch_sampler
,
shard
_id
=
shard_id
,
num_shards
=
num_shards
)
batch_sampler
=
ShardedIterator
(
batch_sampler
,
num_
shard
s
,
shard_id
,
fill_value
=
[]
)
return
torch
.
utils
.
data
.
DataLoader
(
dataset
,
num_workers
=
num_workers
,
collate_fn
=
dataset
.
collater
,
batch_sampler
=
batch_sampler
)
generate.py
View file @
24d7de44
...
...
@@ -58,10 +58,7 @@ def main(args):
max_positions
=
max_positions
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
,
)
if
args
.
num_shards
>
1
:
if
args
.
shard_id
<
0
or
args
.
shard_id
>=
args
.
num_shards
:
raise
ValueError
(
'--shard-id must be between 0 and num_shards'
)
itr
=
data_utils
.
sharded_iterator
(
itr
,
args
.
num_shards
,
args
.
shard_id
)
itr
=
data_utils
.
ShardedIterator
(
itr
,
args
.
num_shards
,
args
.
shard_id
)
# Initialize generator
gen_timer
=
StopwatchMeter
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment