Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e238e3d5
"...resnet50_tensorflow.git" did not exist on "c5763f442973d4852d34d9999f1cf9cf11499dea"
Unverified
Commit
e238e3d5
authored
Jul 17, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 17, 2020
Browse files
[seq2seq] Don't copy self.source in sortishsampler (#5818)
parent
2e4624b4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
10 deletions
+3
-10
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+3
-10
No files found.
examples/seq2seq/utils.py
View file @
e238e3d5
...
...
@@ -144,16 +144,9 @@ class SummarizationDataset(Dataset):
batch
=
{
"input_ids"
:
source_ids
,
"attention_mask"
:
source_mask
,
"decoder_input_ids"
:
y
}
return
batch
@
property
def
src_lens
(
self
):
# Can delete?
return
lmap
(
len
,
self
.
source
)
@
property
def
tgt_lens
(
self
):
return
lmap
(
len
,
self
.
target
)
def
make_sortish_sampler
(
self
,
batch_size
):
return
SortishSampler
(
self
.
source
,
batch_size
)
lens
=
[
x
[
"input_ids"
].
ne
(
self
.
pad_token_id
).
sum
()
for
x
in
self
.
source
]
return
SortishSampler
(
lens
,
batch_size
)
class
SortishSampler
(
Sampler
):
...
...
@@ -163,7 +156,7 @@ class SortishSampler(Sampler):
self
.
data
,
self
.
bs
=
data
,
batch_size
def
key
(
self
,
i
):
return
len
(
self
.
data
[
i
]
)
return
self
.
data
[
i
]
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment