Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f1a4e06f
Unverified
Commit
f1a4e06f
authored
Jul 20, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 20, 2020
Browse files
[Fix] seq2seq pack_dataset.py actually packs (#5913)
Huge MT speedup!
parent
32883b31
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
8 deletions
+33
-8
examples/seq2seq/pack_dataset.py
examples/seq2seq/pack_dataset.py
+25
-8
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+8
-0
No files found.
examples/seq2seq/pack_dataset.py
View file @
f1a4e06f
...
@@ -16,13 +16,14 @@ from transformers import AutoTokenizer
...
@@ -16,13 +16,14 @@ from transformers import AutoTokenizer
def
pack_examples
(
tok
,
src_examples
,
tgt_examples
,
max_tokens
=
1024
):
def
pack_examples
(
tok
,
src_examples
,
tgt_examples
,
max_tokens
=
1024
):
finished_src
,
finished_tgt
=
[],
[]
finished_src
,
finished_tgt
=
[],
[]
new_src
,
new_tgt
=
""
,
""
sorted_examples
=
list
(
sorted
(
zip
(
src_examples
,
tgt_examples
),
key
=
lambda
x
:
len
(
x
[
0
])))
sorted_examples
=
list
(
sorted
(
zip
(
src_examples
,
tgt_examples
),
key
=
lambda
x
:
len
(
x
[
0
])))
new_src
,
new_tgt
=
sorted_examples
[
0
]
def
is_too_big
(
strang
):
def
is_too_big
(
strang
):
return
tok
(
strang
,
return_tensors
=
"pt"
).
input_ids
.
shape
[
1
]
>
max_tokens
return
tok
(
strang
,
return_tensors
=
"pt"
).
input_ids
.
shape
[
1
]
>
max_tokens
for
src
,
tgt
in
tqdm
(
sorted_examples
):
for
src
,
tgt
in
tqdm
(
sorted_examples
[
1
:]
):
cand_src
=
new_src
+
" "
+
src
cand_src
=
new_src
+
" "
+
src
cand_tgt
=
new_tgt
+
" "
+
tgt
cand_tgt
=
new_tgt
+
" "
+
tgt
if
is_too_big
(
cand_src
)
or
is_too_big
(
cand_tgt
):
# cant fit, finalize example
if
is_too_big
(
cand_src
)
or
is_too_big
(
cand_tgt
):
# cant fit, finalize example
...
@@ -31,21 +32,37 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
...
@@ -31,21 +32,37 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
new_src
,
new_tgt
=
src
,
tgt
new_src
,
new_tgt
=
src
,
tgt
else
:
# can fit, keep adding
else
:
# can fit, keep adding
new_src
,
new_tgt
=
cand_src
,
cand_tgt
new_src
,
new_tgt
=
cand_src
,
cand_tgt
# import ipdb; ipdb.set_trace()
# cleanup
if
new_src
:
assert
new_tgt
finished_src
.
append
(
new_src
)
finished_tgt
.
append
(
new_tgt
)
return
finished_src
,
finished_tgt
return
finished_src
,
finished_tgt
def
minify
(
src_dir
:
Path
,
dest_dir
:
Path
,
n
:
int
):
"""Write first n lines of each file f in src_dir to dest_dir/f"""
dest_dir
.
mkdir
(
exist_ok
=
True
)
for
path
in
src_dir
.
iterdir
():
new
=
[
x
.
rstrip
()
for
x
in
list
(
path
.
open
().
readlines
())][:
n
]
dest_path
=
dest_dir
.
joinpath
(
path
.
name
)
print
(
dest_path
)
dest_path
.
open
(
"w"
).
write
(
"
\n
"
.
join
(
new
))
def
pack_data_dir
(
tok
,
data_dir
:
Path
,
max_tokens
,
save_path
):
def
pack_data_dir
(
tok
,
data_dir
:
Path
,
max_tokens
,
save_path
):
save_path
=
Path
(
save_path
)
save_path
=
Path
(
save_path
)
save_path
.
mkdir
(
exist_ok
=
True
)
save_path
.
mkdir
(
exist_ok
=
True
)
for
split
in
[
"val"
,
"test"
,
"train"
]:
for
split
in
[
"val"
,
"test"
,
"train"
]:
src_path
,
tgt_path
=
data_dir
/
f
"
{
split
}
.source"
,
data_dir
/
f
"
{
split
}
.target"
src_path
,
tgt_path
=
data_dir
/
f
"
{
split
}
.source"
,
data_dir
/
f
"
{
split
}
.target"
src_docs
=
list
(
Path
(
src_path
).
open
().
readlines
()
)
src_docs
=
[
x
.
rstrip
()
for
x
in
Path
(
src_path
).
open
().
readlines
()
]
tgt_docs
=
list
(
Path
(
tgt_path
).
open
().
readlines
()
)
tgt_docs
=
[
x
.
rstrip
()
for
x
in
Path
(
tgt_path
).
open
().
readlines
()
]
src
,
tgt
=
pack_examples
(
tok
,
src_docs
,
tgt_docs
,
max_tokens
)
packed_src
,
packed_
tgt
=
pack_examples
(
tok
,
src_docs
,
tgt_docs
,
max_tokens
)
print
(
f
"packed
{
split
}
split from
{
len
(
src_docs
)
}
examples ->
{
len
(
src
)
}
."
)
print
(
f
"packed
{
split
}
split from
{
len
(
src_docs
)
}
examples ->
{
len
(
packed_
src
)
}
."
)
Path
(
save_path
/
f
"
{
split
}
.source"
).
open
(
"w"
).
write
(
"
\n
"
.
join
(
src
))
Path
(
save_path
/
f
"
{
split
}
.source"
).
open
(
"w"
).
write
(
"
\n
"
.
join
(
packed_
src
))
Path
(
save_path
/
f
"
{
split
}
.target"
).
open
(
"w"
).
write
(
"
\n
"
.
join
(
tgt
))
Path
(
save_path
/
f
"
{
split
}
.target"
).
open
(
"w"
).
write
(
"
\n
"
.
join
(
packed_
tgt
))
def
packer_cli
():
def
packer_cli
():
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
f1a4e06f
...
@@ -254,11 +254,19 @@ def test_finetune(model):
...
@@ -254,11 +254,19 @@ def test_finetune(model):
def
test_pack_dataset
():
def
test_pack_dataset
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/mbart-large-cc25"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/mbart-large-cc25"
)
tmp_dir
=
Path
(
make_test_data_dir
())
tmp_dir
=
Path
(
make_test_data_dir
())
orig_examples
=
tmp_dir
.
joinpath
(
"train.source"
).
open
().
readlines
()
save_dir
=
Path
(
tempfile
.
mkdtemp
(
prefix
=
"packed_"
))
save_dir
=
Path
(
tempfile
.
mkdtemp
(
prefix
=
"packed_"
))
pack_data_dir
(
tokenizer
,
tmp_dir
,
128
,
save_dir
)
pack_data_dir
(
tokenizer
,
tmp_dir
,
128
,
save_dir
)
orig_paths
=
{
x
.
name
for
x
in
tmp_dir
.
iterdir
()}
orig_paths
=
{
x
.
name
for
x
in
tmp_dir
.
iterdir
()}
new_paths
=
{
x
.
name
for
x
in
save_dir
.
iterdir
()}
new_paths
=
{
x
.
name
for
x
in
save_dir
.
iterdir
()}
packed_examples
=
save_dir
.
joinpath
(
"train.source"
).
open
().
readlines
()
# orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
# desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
assert
len
(
packed_examples
)
<
len
(
orig_examples
)
assert
len
(
packed_examples
)
==
1
assert
len
(
packed_examples
[
0
])
==
sum
(
len
(
x
)
for
x
in
orig_examples
)
assert
orig_paths
==
new_paths
assert
orig_paths
==
new_paths
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment