chenpangpang / transformers · Commits

Commit 2bba7f81, authored Mar 21, 2019 by Matthew Carrigan

Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory.

parent 8733ffcb
Showing 1 changed file with 66 additions and 35 deletions.

examples/lm_finetuning/pregenerate_training_data.py (+66, -35)
 from argparse import ArgumentParser
 from pathlib import Path
 from tqdm import tqdm, trange
+from tempfile import TemporaryDirectory
+import shelve

 from random import random, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
 class DocumentDatabase:
-    def __init__(self, document_list):
-        self.document_list = document_list
-        self.doc_starts = {}
-        self.weighted_doc_samples = []
-        i = 0
-        for doc_idx, doc in enumerate(document_list):
-            self.doc_starts[doc_idx] = i
-            self.weighted_doc_samples.extend([doc_idx] * len(doc))
-            i += len(doc)
+    def __init__(self, reduce_memory=False, working_dir=None):
+        if reduce_memory:
+            if working_dir is None:
+                self.temp_dir = TemporaryDirectory()
+                self.working_dir = Path(self.temp_dir.name)
+            else:
+                self.temp_dir = None
+                self.working_dir = Path(working_dir)
+                self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.document_shelf_filepath = self.working_dir / 'shelf.db'
+            self.document_shelf = shelve.open(str(self.document_shelf_filepath),
+                                              flag='n', protocol=-1)
+            self.documents = None
+        else:
+            self.documents = []
+            self.document_shelf = None
+            self.document_shelf_filepath = None
+        self.doc_lengths = []
+        self.doc_cumsum = None
+        self.cumsum_max = None
+        self.reduce_memory = reduce_memory
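A minimal sketch, not part of the diff, of the shelve pattern the new constructor relies on, to make the flag='n' and protocol=-1 choices concrete; the shelf path and sample document below are made up.

import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    shelf_path = Path(tmp) / 'shelf.db'    # hypothetical path, mirrors 'shelf.db' above
    # flag='n' always creates a fresh, empty database; protocol=-1 pickles values
    # with the highest pickle protocol available.
    shelf = shelve.open(str(shelf_path), flag='n', protocol=-1)
    shelf['0'] = [['a', 'tokenized', 'sentence'], ['another', 'one']]   # shelf keys must be strings
    print(shelf['0'])                      # values are unpickled from disc on access
    shelf.close()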
+    def add_document(self, document):
+        if self.reduce_memory:
+            current_idx = len(self.doc_lengths)
+            self.document_shelf[str(current_idx)] = document
+        else:
+            self.documents.append(document)
+        self.doc_lengths.append(len(document))
+    def _precalculate_doc_weights(self):
+        self.doc_cumsum = np.cumsum(self.doc_lengths)
+        self.cumsum_max = self.doc_cumsum[-1]
     def sample_doc(self, current_idx, sentence_weighted=True):
         # Uses the current iteration counter to ensure we don't sample the same doc twice
         if sentence_weighted:
-            num_sentences = len(self.document_list[current_idx])
-            # This very painful line randomly selects a document, weighted by the number of sentences they contain,
-            # while guaranteeing that it won't return the original document
-            sampled_val = ((self.doc_starts[current_idx] + num_sentences +
-                            randint(0, len(self.weighted_doc_samples) - num_sentences - 1))
-                           % len(self.weighted_doc_samples))
-            sampled_doc_index = self.weighted_doc_samples[sampled_val]
+            # With sentence weighting, we sample docs proportionally to their sentence length
+            if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
+                self._precalculate_doc_weights()
+            rand_start = self.doc_cumsum[current_idx]
+            rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
+            sentence_index = randint(rand_start, rand_end) % self.cumsum_max
+            sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.document_list)-1)
+            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
         assert sampled_doc_index != current_idx
-        return self.document_list[sampled_doc_index]
+        if self.reduce_memory:
+            return self.document_shelf[str(sampled_doc_index)]
+        else:
+            return self.documents[sampled_doc_index]
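A small deterministic sketch, not part of the diff, of the cumulative-sum / searchsorted mapping used by the new sentence-weighted path; the document lengths here are made up, and the loop shows how a global sentence index resolves to the document that owns it.

import numpy as np

doc_lengths = [3, 5, 2]                  # hypothetical sentence counts per document
doc_cumsum = np.cumsum(doc_lengths)      # array([ 3,  8, 10])
cumsum_max = int(doc_cumsum[-1])         # 10 sentences in total

# Sentences 0-2 belong to doc 0, 3-7 to doc 1, 8-9 to doc 2.
for sentence_index in [0, 2, 3, 7, 8, 9]:
    doc = int(np.searchsorted(doc_cumsum, sentence_index, side='right'))
    print(sentence_index, '->', doc)     # 0->0, 2->0, 3->1, 7->1, 8->2, 9->2

# sample_doc above draws a value starting at doc_cumsum[current_idx], i.e. just past
# the current document's cumulative range, and wraps it modulo cumsum_max before
# mapping it back to a document index with the same searchsorted call.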
     def __len__(self):
-        return len(self.document_list)
+        return len(self.doc_lengths)
     def __getitem__(self, item):
-        return self.document_list[item]
+        if self.reduce_memory:
+            return self.document_shelf[str(item)]
+        else:
+            return self.documents[item]
+    def cleanup(self):
+        if self.document_shelf is not None:
+            self.document_shelf.close()
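A minimal usage sketch, not part of the diff, assuming the DocumentDatabase class as reconstructed above; the token lists are made up.

# In disc-backed mode, documents are written to a temporary shelf as they are added.
db = DocumentDatabase(reduce_memory=True)
db.add_document([['first', 'sentence'], ['second', 'sentence']])
db.add_document([['another', 'doc']])

print(len(db))       # 2 documents, tracked via doc_lengths
print(db[1])         # [['another', 'doc']], read back from the shelf
db.cleanup()         # closes the shelf file

# In the default mode the same interface keeps everything in a plain list.
db2 = DocumentDatabase(reduce_memory=False)
db2.add_document([['kept', 'in', 'memory']])
print(db2[0])
db2.cleanup()        # no-op when there is no shelf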
 def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
...
@@ -200,6 +235,11 @@ def main():
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
+    parser.add_argument("--working_dir", type=Path, default=None,
+                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
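A small sketch, not part of the diff, of how the two new flags parse; the argv list and path are hypothetical, and the parsed values are what main() later hands to DocumentDatabase.

from argparse import ArgumentParser
from pathlib import Path

parser = ArgumentParser()
parser.add_argument("--reduce_memory", action="store_true")
parser.add_argument("--working_dir", type=Path, default=None)

# Hypothetical invocation: pregenerate_training_data.py --reduce_memory --working_dir /tmp/pregen
args = parser.parse_args(["--reduce_memory", "--working_dir", "/tmp/pregen"])
print(args.reduce_memory)   # True
print(args.working_dir)     # /tmp/pregen, as a pathlib.Path, later passed to DocumentDatabase(...)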
...
@@ -212,31 +252,21 @@ def main():
     args = parser.parse_args()

-    # TODO Add a low-memory / multiprocessing path for very large datasets
-    # In this path documents would be stored in a shelf after being tokenized, and multiple processes would convert
-    # those docs into training examples that would be written out on the fly. This would avoid the need to keep
-    # the whole training set in memory and would speed up dataset creation at the cost of code complexity.
-    # In addition, the finetuning script would need to be modified
-    # to store the training epochs as memmapped arrays.
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
     with args.train_corpus.open() as f:
-        docs = []
         doc = []
-        for line in tqdm(f, desc="Loading Dataset"):
+        for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
             line = line.strip()
             if line == "":
-                docs.append(doc)
+                docs.add_document(doc)
                 doc = []
             else:
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)

     args.output_dir.mkdir(exist_ok=True)
-    docs = DocumentDatabase(docs)
-    # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
-    # Google BERT doesn't do this, and as a result oversamples shorter docs

     for epoch in trange(args.epochs_to_generate, desc="Epoch"):
         epoch_filename = args.output_dir / f"epoch_{epoch}.json"
         num_instances = 0
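A tiny sketch, not part of the diff, of the corpus format the loading loop above expects: blank lines mark document boundaries. The in-memory list stands in for args.train_corpus and str.split stands in for tokenizer.tokenize.

corpus_lines = [
    "The first document, sentence one.",
    "Sentence two of the first document.",
    "",                                    # blank line ends a document
    "The second document has one sentence.",
    "",
]

documents, doc = [], []
for line in corpus_lines:
    line = line.strip()
    if line == "":
        documents.append(doc)              # the script calls docs.add_document(doc) here
        doc = []
    else:
        doc.append(line.split())           # stand-in for tokenizer.tokenize(line)

print(len(documents))                      # 2
print(documents[1])                        # [['The', 'second', 'document', 'has', 'one', 'sentence.']]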
...
@@ -257,6 +287,7 @@ def main():
                        "max_seq_len": args.max_seq_len}
             metrics_file.write(json.dumps(metrics))

+    docs.cleanup()

 if __name__ == '__main__':
...