Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dbbd6c75
Commit
dbbd6c75
authored
Apr 12, 2019
by
Matthew Carrigan
Browse files
Replaced some randints with cleaner randranges, and added a helpful
error for users whose corpus is just one giant document.
parent
61674333
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
5 deletions
+15
-5
examples/lm_finetuning/pregenerate_training_data.py
examples/lm_finetuning/pregenerate_training_data.py
+15
-5
No files found.
examples/lm_finetuning/pregenerate_training_data.py
View file @
dbbd6c75
...
...
@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
from
tempfile
import
TemporaryDirectory
import
shelve
from
random
import
random
,
randint
,
shuffle
,
choice
,
sample
from
random
import
random
,
randrange
,
randint
,
shuffle
,
choice
,
sample
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
import
numpy
as
np
import
json
...
...
@@ -30,6 +30,8 @@ class DocumentDatabase:
self
.
reduce_memory
=
reduce_memory
def
add_document
(
self
,
document
):
if
not
document
:
return
if
self
.
reduce_memory
:
current_idx
=
len
(
self
.
doc_lengths
)
self
.
document_shelf
[
str
(
current_idx
)]
=
document
...
...
@@ -49,11 +51,11 @@ class DocumentDatabase:
self
.
_precalculate_doc_weights
()
rand_start
=
self
.
doc_cumsum
[
current_idx
]
rand_end
=
rand_start
+
self
.
cumsum_max
-
self
.
doc_lengths
[
current_idx
]
sentence_index
=
rand
int
(
rand_start
,
rand_end
-
1
)
%
self
.
cumsum_max
sentence_index
=
rand
range
(
rand_start
,
rand_end
)
%
self
.
cumsum_max
sampled_doc_index
=
np
.
searchsorted
(
self
.
doc_cumsum
,
sentence_index
,
side
=
'right'
)
else
:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index
=
current_idx
+
rand
int
(
1
,
len
(
self
.
doc_lengths
)
-
1
)
sampled_doc_index
=
(
current_idx
+
rand
range
(
1
,
len
(
self
.
doc_lengths
)
))
%
len
(
self
.
doc_lengths
)
assert
sampled_doc_index
!=
current_idx
if
self
.
reduce_memory
:
return
self
.
document_shelf
[
str
(
sampled_doc_index
)]
...
...
@@ -170,7 +172,7 @@ def create_instances_from_document(
# (first) sentence.
a_end
=
1
if
len
(
current_chunk
)
>=
2
:
a_end
=
rand
int
(
1
,
len
(
current_chunk
)
-
1
)
a_end
=
rand
range
(
1
,
len
(
current_chunk
))
tokens_a
=
[]
for
j
in
range
(
a_end
):
...
...
@@ -186,7 +188,7 @@ def create_instances_from_document(
# Sample a random document, with longer docs being sampled more frequently
random_document
=
doc_database
.
sample_doc
(
current_idx
=
doc_idx
,
sentence_weighted
=
True
)
random_start
=
rand
int
(
0
,
len
(
random_document
)
-
1
)
random_start
=
rand
range
(
0
,
len
(
random_document
))
for
j
in
range
(
random_start
,
len
(
random_document
)):
tokens_b
.
extend
(
random_document
[
j
])
if
len
(
tokens_b
)
>=
target_b_length
:
...
...
@@ -264,6 +266,14 @@ def main():
else
:
tokens
=
tokenizer
.
tokenize
(
line
)
doc
.
append
(
tokens
)
if
doc
:
docs
.
add_document
(
doc
)
# If the last doc didn't end on a newline, make sure it still gets added
if
len
(
docs
)
<=
1
:
exit
(
"ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
"ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
"indicate breaks between documents in your input file. If your dataset does not contain multiple "
"documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
"sections or paragraphs."
)
args
.
output_dir
.
mkdir
(
exist_ok
=
True
)
for
epoch
in
trange
(
args
.
epochs_to_generate
,
desc
=
"Epoch"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment