Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dbbd6c75
Commit
dbbd6c75
authored
Apr 12, 2019
by
Matthew Carrigan
Browse files
Replaced some randints with cleaner randranges, and added a helpful
error for users whose corpus is just one giant document.
parent
61674333
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
5 deletions
+15
-5
examples/lm_finetuning/pregenerate_training_data.py
examples/lm_finetuning/pregenerate_training_data.py
+15
-5
No files found.
examples/lm_finetuning/pregenerate_training_data.py
View file @
dbbd6c75
...
...
@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
from
tempfile
import
TemporaryDirectory
import
shelve
from
random
import
random
,
randint
,
shuffle
,
choice
,
sample
from
random
import
random
,
randrange
,
randint
,
shuffle
,
choice
,
sample
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
import
numpy
as
np
import
json
...
...
@@ -30,6 +30,8 @@ class DocumentDatabase:
self
.
reduce_memory
=
reduce_memory
def
add_document
(
self
,
document
):
if
not
document
:
return
if
self
.
reduce_memory
:
current_idx
=
len
(
self
.
doc_lengths
)
self
.
document_shelf
[
str
(
current_idx
)]
=
document
...
...
@@ -49,11 +51,11 @@ class DocumentDatabase:
self
.
_precalculate_doc_weights
()
rand_start
=
self
.
doc_cumsum
[
current_idx
]
rand_end
=
rand_start
+
self
.
cumsum_max
-
self
.
doc_lengths
[
current_idx
]
sentence_index
=
rand
int
(
rand_start
,
rand_end
-
1
)
%
self
.
cumsum_max
sentence_index
=
rand
range
(
rand_start
,
rand_end
)
%
self
.
cumsum_max
sampled_doc_index
=
np
.
searchsorted
(
self
.
doc_cumsum
,
sentence_index
,
side
=
'right'
)
else
:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index
=
current_idx
+
rand
int
(
1
,
len
(
self
.
doc_lengths
)
-
1
)
sampled_doc_index
=
(
current_idx
+
rand
range
(
1
,
len
(
self
.
doc_lengths
)
))
%
len
(
self
.
doc_lengths
)
assert
sampled_doc_index
!=
current_idx
if
self
.
reduce_memory
:
return
self
.
document_shelf
[
str
(
sampled_doc_index
)]
...
...
@@ -170,7 +172,7 @@ def create_instances_from_document(
# (first) sentence.
a_end
=
1
if
len
(
current_chunk
)
>=
2
:
a_end
=
rand
int
(
1
,
len
(
current_chunk
)
-
1
)
a_end
=
rand
range
(
1
,
len
(
current_chunk
))
tokens_a
=
[]
for
j
in
range
(
a_end
):
...
...
@@ -186,7 +188,7 @@ def create_instances_from_document(
# Sample a random document, with longer docs being sampled more frequently
random_document
=
doc_database
.
sample_doc
(
current_idx
=
doc_idx
,
sentence_weighted
=
True
)
random_start
=
rand
int
(
0
,
len
(
random_document
)
-
1
)
random_start
=
rand
range
(
0
,
len
(
random_document
))
for
j
in
range
(
random_start
,
len
(
random_document
)):
tokens_b
.
extend
(
random_document
[
j
])
if
len
(
tokens_b
)
>=
target_b_length
:
...
...
@@ -264,6 +266,14 @@ def main():
else
:
tokens
=
tokenizer
.
tokenize
(
line
)
doc
.
append
(
tokens
)
if
doc
:
docs
.
add_document
(
doc
)
# If the last doc didn't end on a newline, make sure it still gets added
if
len
(
docs
)
<=
1
:
exit
(
"ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
"ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
"indicate breaks between documents in your input file. If your dataset does not contain multiple "
"documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
"sections or paragraphs."
)
args
.
output_dir
.
mkdir
(
exist_ok
=
True
)
for
epoch
in
trange
(
args
.
epochs_to_generate
,
desc
=
"Epoch"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment