OpenDAS / Fairseq · Commit 2ad58885

Refactor PaddingCollater

Authored Sep 27, 2017 by Myle Ott
Parent 4593ebfa

Showing 1 changed file, fairseq/data.py, with 72 additions and 51 deletions.

fairseq/data.py (view file @ 2ad58885)
@@ -70,8 +70,8 @@ def load(path, src, dst):
         dataset.splits[prefix] = LanguagePairDataset(
             IndexedInMemoryDataset(src_path),
             IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
-            padding_value=dataset.src_dict.pad(),
-            eos=dataset.src_dict.eos(),
+            pad_idx=dataset.src_dict.pad(),
+            eos_idx=dataset.src_dict.eos(),
         )

     return dataset
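For orientation, a minimal sketch of constructing a LanguagePairDataset under the renamed keywords. The plain Python lists and token values below are hypothetical stand-ins for IndexedInMemoryDataset, which this sketch assumes behaves as an indexable collection of 1-based token-id tensors (as the __getitem__ changes later in this diff imply):

    import torch
    from fairseq.data import LanguagePairDataset

    # Hypothetical stand-ins for IndexedInMemoryDataset: any indexable
    # collection of 1-based token-id tensors works, since __getitem__
    # subtracts 1 and casts to long. Token values are invented.
    src = [torch.IntTensor([5, 6, 3])]  # becomes [4, 5, 2] after the -1 shift
    dst = [torch.IntTensor([8, 3])]     # becomes [7, 2]; 2 is the EOS index here

    pairs = LanguagePairDataset(src, dst, pad_idx=1, eos_idx=2)
    print(pairs[0])
    # {'id': 0, 'source': tensor([4, 5, 2]), 'target': tensor([7, 2])}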
@@ -85,6 +85,10 @@ class LanguageDatasets(object):
         self.dst_dict = dst_dict
         self.splits = {}
+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()

     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024):
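The new assertions pin down the invariant the refactor relies on: source and target dictionaries must share the same special-symbol indices, so a single pad_idx/eos_idx pair can drive collation for both sides. A toy illustration, with ToyDict as a hypothetical stand-in for fairseq's dictionary type:

    # ToyDict is a hypothetical stand-in exposing fairseq-style accessors for
    # the special symbols; the concrete index values are illustrative only.
    class ToyDict(object):
        def pad(self): return 1
        def eos(self): return 2
        def unk(self): return 3

    src_dict, dst_dict = ToyDict(), ToyDict()

    # Mirrors the new assertions: both languages must agree on the special
    # indices, because collate() below takes one pad_idx and one eos_idx.
    assert src_dict.pad() == dst_dict.pad()
    assert src_dict.eos() == dst_dict.eos()
    assert src_dict.unk() == dst_dict.unk()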
@@ -105,8 +109,9 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            collate_fn=PaddingCollater(self.src_dict.pad()),
-            batch_sampler=batch_sampler)
+            collate_fn=dataset.collater,
+            batch_sampler=batch_sampler,
+        )


 def skip_group_enumerator(it, ngpus, offset=0):
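Passing the bound method dataset.collater as collate_fn keeps the padding indices with the dataset instead of in a separate PaddingCollater object. A runnable sketch of the wiring with a toy dataset (everything below is a hypothetical stand-in, not fairseq code):

    import torch
    import torch.utils.data

    class ToyPairs(torch.utils.data.Dataset):
        """Hypothetical stand-in whose collater mimics the new pattern."""
        def __init__(self):
            self.targets = [torch.LongTensor([4, 5, 2]), torch.LongTensor([7, 2])]
        def __len__(self):
            return len(self.targets)
        def __getitem__(self, i):
            return {'id': torch.tensor(i), 'target': self.targets[i]}
        def collater(self, samples):
            # Stand-in for LanguagePairDataset.collate(samples, pad_idx, eos_idx).
            return {'id': torch.LongTensor([s['id'].item() for s in samples]),
                    'ntokens': sum(len(s['target']) for s in samples)}

    dataset = ToyPairs()
    loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collater,  # bound method, as in the hunk above
        batch_sampler=[[0, 1]],       # one batch holding both example indices
    )
    print(next(iter(loader)))  # {'id': tensor([0, 1]), 'ntokens': 5}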
@@ -124,67 +129,83 @@ def skip_group_enumerator(it, ngpus, offset=0):
         yield (idx, res)


-class PaddingCollater(object):
-    def __init__(self, padding_value=1):
-        self.padding_value = padding_value
-
-    def __call__(self, samples):
-        def merge(key, pad_begin):
-            return self.merge_with_pad([s[key] for s in samples], pad_begin)
-
-        ntokens = sum(len(s['target']) for s in samples)
-
-        return {
-            'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'input_tokens': merge('input_tokens', pad_begin=True),
-            'input_positions': merge('input_positions', pad_begin=True),
-            'target': merge('target', pad_begin=True),
-            'src_tokens': merge('src_tokens', pad_begin=False),
-            'src_positions': merge('src_positions', pad_begin=False),
-            'ntokens': ntokens,
-        }
-
-    def merge_with_pad(self, values, pad_begin):
-        size = max(v.size(0) for v in values)
-        res = values[0].new(len(values), size).fill_(self.padding_value)
-        for i, v in enumerate(values):
-            if pad_begin:
-                res[i][size - len(v):].copy_(v)
-            else:
-                res[i][:len(v)].copy_(v)
-        return res
-
-
 class LanguagePairDataset(object):

-    def __init__(self, src, dst, padding_value=1, eos=2):
+    def __init__(self, src, dst, pad_idx, eos_idx):
         self.src = src
         self.dst = dst
-        self.padding_value = padding_value
-        self.eos = eos
+        self.pad_idx = pad_idx
+        self.eos_idx = eos_idx

     def __getitem__(self, i):
-        src = self.src[i].long() - 1  # subtract 1 for 0-based indexing
+        source = self.src[i].long() - 1  # subtract 1 for 0-based indexing
         target = self.dst[i].long() - 1
-        input = target.new(target.size())
-        input[0] = self.eos
-        input[1:].copy_(target[:-1])
         return {
             'id': i,
-            'input_tokens': input,
-            'input_positions': self.make_positions(input),
+            'source': source,
             'target': target,
-            'src_tokens': src,
-            'src_positions': self.make_positions(src),
         }

-    def make_positions(self, x):
-        start = self.padding_value + 1
-        return torch.arange(start, start + len(x)).type_as(x)
-
     def __len__(self):
         return len(self.src)

+    def collater(self, samples):
+        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
+
+    @staticmethod
+    def collate(samples, pad_idx, eos_idx):
+        def merge(key, left_pad, move_eos_to_beginning=False):
+            return LanguagePairDataset.collate_tokens(
+                [s[key] for s in samples],
+                pad_idx, eos_idx, left_pad, move_eos_to_beginning)
+
+        def merge_positions(key, left_pad):
+            return LanguagePairDataset.collate_positions(
+                [s[key] for s in samples], pad_idx, left_pad)
+
+        ntokens = sum(len(s['target']) for s in samples)
+
+        return {
+            'id': torch.LongTensor([s['id'].item() for s in samples]),
+            'input_tokens': merge('target', left_pad=True, move_eos_to_beginning=True),
+            'input_positions': merge_positions('target', left_pad=True),
+            'target': merge('target', left_pad=True),
+            'src_tokens': merge('source', left_pad=False),
+            'src_positions': merge_positions('source', left_pad=False),
+            'ntokens': ntokens,
+        }
+
+    @staticmethod
+    def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning):
+        size = max(v.size(0) for v in values)
+        res = values[0].new(len(values), size).fill_(pad_idx)
+
+        def copy_tensor(src, dst):
+            assert dst.numel() == src.numel()
+            if move_eos_to_beginning:
+                assert src[-1] == eos_idx
+                dst[0] = eos_idx
+                dst[1:] = src[:-1]
+            else:
+                dst.copy_(src)
+
+        for i, v in enumerate(values):
+            if left_pad:
+                copy_tensor(v, res[i][size - len(v):])
+            else:
+                copy_tensor(v, res[i][:len(v)])
+        return res
+
+    @staticmethod
+    def collate_positions(values, pad_idx, left_pad):
+        start = pad_idx + 1
+        size = max(v.size(0) for v in values)
+        res = values[0].new(len(values), size).fill_(pad_idx)
+        for i, v in enumerate(values):
+            if left_pad:
+                torch.arange(start, start + len(v), out=res[i][size - len(v):])
+            else:
+                torch.arange(start, start + len(v), out=res[i][:len(v)])
+        return res
+
+
 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
     """Returns batches of indices sorted by size. Sequences of different lengths
...
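To make the new helpers' padding conventions concrete, a small worked example, assuming the module imports as fairseq.data; pad_idx=1 and eos_idx=2 follow the toy values used earlier, and the token contents are invented:

    import torch
    from fairseq.data import LanguagePairDataset

    # Two targets of unequal length, both ending in EOS (index 2), pad index 1.
    values = [torch.LongTensor([4, 5, 6, 2]), torch.LongTensor([7, 2])]

    targets = LanguagePairDataset.collate_tokens(
        values, pad_idx=1, eos_idx=2, left_pad=True, move_eos_to_beginning=False)
    # [[4, 5, 6, 2],
    #  [1, 1, 7, 2]]  <- shorter row left-padded with pad_idx

    inputs = LanguagePairDataset.collate_tokens(
        values, pad_idx=1, eos_idx=2, left_pad=True, move_eos_to_beginning=True)
    # [[2, 4, 5, 6],
    #  [1, 1, 2, 7]]  <- EOS rotated to the front: the shifted decoder input

    positions = LanguagePairDataset.collate_positions(values, pad_idx=1, left_pad=True)
    # [[2, 3, 4, 5],
    #  [1, 1, 2, 3]]  <- positions count up from pad_idx + 1; padding keeps pad_idx

Left-padding the targets and right-padding the sources matches how collate() invokes merge and merge_positions in the hunk above.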