chenpangpang / transformers · Commit 283500ff

Unverified commit 283500ff, authored Jul 16, 2020 by Sam Shleifer, committed by GitHub on Jul 16, 2020.

[seq2seq] pack_dataset.py rewrites dataset in max_tokens format (#5819)

parent c45d7a70
Showing 2 changed files with 74 additions and 0 deletions (+74 / -0):

examples/seq2seq/pack_dataset.py (+63 / -0)
examples/seq2seq/test_seq2seq_examples.py (+11 / -0)
examples/seq2seq/pack_dataset.py (new file, 0 → 100644)
"""Fill examples with bitext up to max_tokens without breaking up examples.
[['I went', 'yo fui'],
['to the store', 'a la tienda']
]
=> ['I went to the store', 'yo fui a la tienda']
"""
import argparse
from pathlib import Path

from tqdm import tqdm

from transformers import AutoTokenizer


def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
    finished_src, finished_tgt = [], []

    new_src, new_tgt = "", ""
    sorted_examples = list(sorted(zip(src_examples, tgt_examples), key=lambda x: len(x[0])))

    def is_too_big(strang):
        return tok(strang, return_tensors="pt").input_ids.shape[1] > max_tokens

    for src, tgt in tqdm(sorted_examples):
        cand_src = new_src + " " + src
        cand_tgt = new_tgt + " " + tgt
        if is_too_big(cand_src) or is_too_big(cand_tgt):  # cant fit, finalize example
            finished_src.append(new_src)
            finished_tgt.append(new_tgt)
            new_src, new_tgt = src, tgt
        else:  # can fit, keep adding
            new_src, new_tgt = cand_src, cand_tgt
    return finished_src, finished_tgt


def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
    save_path = Path(save_path)
    save_path.mkdir(exist_ok=True)
    for split in ["val", "test", "train"]:
        src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
        src_docs = list(Path(src_path).open().readlines())
        tgt_docs = list(Path(tgt_path).open().readlines())
        src, tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens)
        print(f"packed {split} split from {len(src_docs)} examples -> {len(src)}.")
        Path(save_path / f"{split}.source").open("w").write("\n".join(src))
        Path(save_path / f"{split}.target").open("w").write("\n".join(tgt))


def packer_cli():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tok_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--save_path", type=str)
    args = parser.parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.tok_name)
    return pack_data_dir(tokenizer, Path(args.data_dir), args.max_seq_len, args.save_path)


if __name__ == "__main__":
    packer_cli()
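As a quick illustration of the packing behavior described in the docstring, here is a minimal sketch that calls pack_examples directly on a toy bitext. The t5-small checkpoint, the example sentences, and the import path are assumptions for illustration only; they are not part of this commit.

# Sketch only: the toy bitext and tokenizer choice are assumptions, not from the commit.
from transformers import AutoTokenizer

from pack_dataset import pack_examples  # assumes examples/seq2seq is the working directory

tok = AutoTokenizer.from_pretrained("t5-small")  # any AutoTokenizer-loadable checkpoint works
src = ["I went", "to the store", "and bought milk"]
tgt = ["yo fui", "a la tienda", "y compré leche"]

packed_src, packed_tgt = pack_examples(tok, src, tgt, max_tokens=8)
# Note: as written, the loop only emits a packed example once the next candidate
# would overflow max_tokens, so the final in-progress example is not returned.
print(packed_src)
print(packed_tgt)

From the command line, packer_cli exposes the same logic over a whole data directory via --tok_name, --max_seq_len, --data_dir, and --save_path.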
examples/seq2seq/test_seq2seq_examples.py
@@ -16,6 +16,7 @@ from transformers.testing_utils import require_multigpu
 from .distillation import distill_main, evaluate_checkpoint
 from .finetune import main
+from .pack_dataset import pack_data_dir
 from .run_eval import generate_summaries_or_translations, run_generate
 from .utils import SummarizationDataset, lmap, load_json
@@ -249,6 +250,16 @@ def test_finetune(model):
     assert bart.decoder.embed_tokens == bart.shared
 
 
+def test_pack_dataset():
+    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+    tmp_dir = Path(make_test_data_dir())
+    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
+    pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
+    orig_paths = {x.name for x in tmp_dir.iterdir()}
+    new_paths = {x.name for x in save_dir.iterdir()}
+    assert orig_paths == new_paths
+
+
 @pytest.mark.parametrize(
     ["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
 )
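For completeness, a hedged sketch of running just the new test locally; the repository-root path and the pytest options are assumptions about a typical checkout, not something the commit prescribes.

# Sketch: assumes it is run from the transformers repo root with pytest installed.
# Equivalent to: pytest examples/seq2seq/test_seq2seq_examples.py -k test_pack_dataset -q
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["examples/seq2seq/test_seq2seq_examples.py", "-k", "test_pack_dataset", "-q"]))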