transformers · Commit 78462aad (unverified)

Authored Jul 05, 2019 by Thomas Wolf; committed via GitHub, Jul 05, 2019

Merge pull request #733 from ceremonious/parallel-generation

Added option to use multiple workers to create training data

Parents: 781124b0, 08ff056c
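For context, a hypothetical invocation of the updated script: the flags --train_corpus, --bert_model, --do_lower_case, --output_dir, --epochs_to_generate, and --num_workers all appear in the diff below, but the corpus path, model name, and output directory here are placeholders.

    python pregenerate_training_data.py \
        --train_corpus corpus.txt \
        --bert_model bert-base-uncased \
        --do_lower_case \
        --output_dir training_data/ \
        --epochs_to_generate 3 \
        --num_workers 3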
Showing 1 changed file with 36 additions and 20 deletions
examples/lm_finetuning/pregenerate_training_data.py (+36 −20, view file @ 78462aad)
@@ -3,6 +3,7 @@ from pathlib import Path
 from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
+from multiprocessing import Pool
 
 from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -264,6 +265,28 @@ def create_instances_from_document(
     return instances
 
 
+def create_training_file(docs, vocab_list, args, epoch_num):
+    epoch_filename = args.output_dir / "epoch_{}.json".format(epoch_num)
+    num_instances = 0
+    with epoch_filename.open('w') as epoch_file:
+        for doc_idx in trange(len(docs), desc="Document"):
+            doc_instances = create_instances_from_document(
+                docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
+                masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
+                whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
+            doc_instances = [json.dumps(instance) for instance in doc_instances]
+            for instance in doc_instances:
+                epoch_file.write(instance + '\n')
+                num_instances += 1
+    metrics_file = args.output_dir / "epoch_{}_metrics.json".format(epoch_num)
+    with metrics_file.open('w') as metrics_file:
+        metrics = {
+            "num_training_examples": num_instances,
+            "max_seq_len": args.max_seq_len
+        }
+        metrics_file.write(json.dumps(metrics))
+
+
 def main():
     parser = ArgumentParser()
     parser.add_argument('--train_corpus', type=Path, required=True)
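The new create_training_file function factors the per-epoch loop body out of main(), so the same code path can run either sequentially or as a Pool task handling one epoch per worker, as the final hunk below shows.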
@@ -277,6 +300,8 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
+    parser.add_argument("--num_workers", type=int, default=1,
+                        help="The number of workers to use to write the files")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
@@ -289,6 +314,9 @@ def main():
     args = parser.parse_args()
 
+    if args.num_workers > 1 and args.reduce_memory:
+        raise ValueError("Cannot use multiple workers while reducing memory")
+
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
     with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
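The new guard presumably exists because --reduce_memory keeps the document database in an on-disk shelve store (note the shelve import), and such a handle cannot safely be shared across worker processes, so the two options are made mutually exclusive.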
@@ -312,26 +340,14 @@ def main():
                  "sections or paragraphs.")
 
         args.output_dir.mkdir(exist_ok=True)
-        for epoch in trange(args.epochs_to_generate, desc="Epoch"):
-            epoch_filename = args.output_dir / f"epoch_{epoch}.json"
-            num_instances = 0
-            with epoch_filename.open('w') as epoch_file:
-                for doc_idx in trange(len(docs), desc="Document"):
-                    doc_instances = create_instances_from_document(
-                        docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
-                        masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                        whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
-                    doc_instances = [json.dumps(instance) for instance in doc_instances]
-                    for instance in doc_instances:
-                        epoch_file.write(instance + '\n')
-                        num_instances += 1
-            metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
-            with metrics_file.open('w') as metrics_file:
-                metrics = {
-                    "num_training_examples": num_instances,
-                    "max_seq_len": args.max_seq_len
-                }
-                metrics_file.write(json.dumps(metrics))
+
+        if args.num_workers > 1:
+            writer_workers = Pool(min(args.num_workers, args.epochs_to_generate))
+            arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)]
+            writer_workers.starmap(create_training_file, arguments)
+        else:
+            for epoch in trange(args.epochs_to_generate, desc="Epoch"):
+                create_training_file(docs, vocab_list, args, epoch)
 
 
 if __name__ == '__main__':
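As a standalone illustration (not part of the commit), here is a minimal sketch of the Pool.starmap fan-out pattern the final hunk uses; the function and file names are invented for the example:

    # Minimal sketch of the fan-out pattern above: starmap unpacks each
    # argument tuple into one call, so each pool worker handles one epoch.
    from multiprocessing import Pool


    def write_epoch(prefix, epoch_num):
        # Stand-in for create_training_file: one epoch's output per task.
        print("writing {}_{}.json".format(prefix, epoch_num))


    if __name__ == '__main__':
        epochs_to_generate = 3
        # Cap the pool size at the number of tasks, as the commit does.
        workers = Pool(min(2, epochs_to_generate))
        arguments = [("epoch", idx) for idx in range(epochs_to_generate)]
        workers.starmap(write_epoch, arguments)  # blocks until all tasks finish
        workers.close()
        workers.join()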