chenpangpang / transformers

Commit 7d1ae644
Authored Mar 21, 2019 by Matthew Carrigan
Parent: 2bba7f81

Added a --reduce_memory option to the training script to keep training
data on disc as a memmap rather than in memory
Changes: 2 files changed, 30 additions and 20 deletions (+30 -20)

examples/lm_finetuning/finetune_on_pregenerated.py    +26 -9
examples/lm_finetuning/pregenerate_training_data.py   +4  -11
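The change rests on one NumPy feature: np.memmap exposes a file on disc as an ordinary ndarray, so arrays far larger than RAM can be created, filled, and indexed with unchanged NumPy code. Below is a minimal sketch of the pattern; the shapes and file name are illustrative, not taken from the commit.

import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory

# Illustrative sizes; the real script reads them from a metrics file.
num_samples, seq_len = 1000, 128

temp_dir = TemporaryDirectory()     # deleted when the object is cleaned up
working_dir = Path(temp_dir.name)

# mode='w+' creates the backing file and allows reads and writes.
input_ids = np.memmap(working_dir / 'input_ids.memmap',
                      mode='w+', dtype=np.int32, shape=(num_samples, seq_len))

input_ids[0] = np.arange(seq_len, dtype=np.int32)  # written through to disc
input_ids.flush()                                  # force dirty pages out

assert input_ids[0, 5] == 5  # ordinary indexing, so downstream code is unchanged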
examples/lm_finetuning/finetune_on_pregenerated.py

@@ -6,6 +6,7 @@ import json
 import random
 import numpy as np
 from collections import namedtuple
+from tempfile import TemporaryDirectory
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -53,8 +54,7 @@ def convert_example_to_features(example, tokenizer, max_seq_length):
 class PregeneratedDataset(Dataset):
-    def __init__(self, training_path, epoch, tokenizer, num_data_epochs):
-        # TODO Add an option to memmap the training data if needed (see note in pregenerate_training_data)
+    def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
         self.vocab = tokenizer.vocab
         self.tokenizer = tokenizer
         self.epoch = epoch
@@ -65,11 +65,28 @@ class PregeneratedDataset(Dataset):
         metrics = json.loads(metrics_file.read_text())
         num_samples = metrics['num_training_examples']
         seq_len = metrics['max_seq_len']
-        input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
-        input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
-        lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
-        is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
+        self.temp_dir = None
+        self.working_dir = None
+        if reduce_memory:
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
+            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
+                                  mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
+            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            # Each array needs its own backing file; reusing 'input_masks.memmap'
+            # here would make segment_ids alias input_masks on disc.
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
+                                     shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
+            lm_label_ids[:] = -1
+            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
+                                 shape=(num_samples,), mode='w+', dtype=np.bool)
+        else:
+            input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
+            input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
+            lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
+            is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
         logging.info(f"Loading training examples for epoch {epoch}")
         with data_file.open() as f:
             for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
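The saving is easy to estimate: the in-memory branch allocates all five arrays up front, and the two int32 arrays cost num_samples * seq_len * 4 bytes each. A quick, purely illustrative footprint check (the corpus size below is assumed, not taken from the commit):

# Rough RAM footprint of the in-memory branch for an assumed corpus size.
num_samples, seq_len = 10_000_000, 128

int32_bytes = num_samples * seq_len * 4        # input_ids, lm_label_ids
bool_bytes = num_samples * seq_len * 1         # input_masks, segment_ids
total = 2 * int32_bytes + 2 * bool_bytes + num_samples  # + is_nexts

print(f"{total / 2**30:.1f} GiB")  # ~11.9 GiB held in RAM without --reduce_memory

With --reduce_memory the same data sits in files inside a TemporaryDirectory, and the OS page cache decides how much of it stays resident.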
@@ -110,6 +127,8 @@ def main():
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Store training data as on-disc memmaps to massively reduce memory usage")
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
     parser.add_argument("--local_rank",
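The new flag uses action="store_true", so it defaults to False and flips to True merely by being present, which keeps existing invocations working unchanged. A self-contained sketch of that behaviour (the parser here is a throwaway, not the script's real one):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--reduce_memory", action="store_true",
                    help="Store training data as on-disc memmaps")

print(parser.parse_args([]).reduce_memory)                   # False
print(parser.parse_args(["--reduce_memory"]).reduce_memory)  # True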
@@ -311,7 +330,5 @@ def main():
     torch.save(model_to_save.state_dict(), str(output_model_file))

 if __name__ == '__main__':
     main()
examples/lm_finetuning/pregenerate_training_data.py

@@ -11,15 +11,10 @@ import json
 class DocumentDatabase:
-    def __init__(self, reduce_memory=False, working_dir=None):
+    def __init__(self, reduce_memory=False):
         if reduce_memory:
-            if working_dir is None:
-                self.temp_dir = TemporaryDirectory()
-                self.working_dir = Path(self.temp_dir.name)
-            else:
-                self.temp_dir = None
-                self.working_dir = Path(working_dir)
-            self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
             self.document_shelf_filepath = self.working_dir / 'shelf.db'
             self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                               flag='n', protocol=-1)
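For the document store, shelve plays the same role that np.memmap plays for the feature arrays: a persistent, dict-like object backed by a database file, so documents are pickled to disc on assignment instead of accumulating in a Python list. A minimal sketch of the pattern (flag='n' always creates a fresh, empty database; protocol=-1 selects the highest pickle protocol; the key scheme is illustrative):

import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

temp_dir = TemporaryDirectory()
shelf_path = Path(temp_dir.name) / 'shelf.db'

shelf = shelve.open(str(shelf_path), flag='n', protocol=-1)

shelf['0'] = ['a tokenized', 'document']   # pickled and written to disc
print(shelf['0'])                          # unpickled back on access
shelf.close()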
@@ -237,8 +232,6 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
-    parser.add_argument("--working_dir", type=Path, default=None,
-                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
@@ -254,7 +247,7 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory)
     with args.train_corpus.open() as f:
         doc = []
         for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
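A side effect of dropping --working_dir is that the shelf now always lives inside a TemporaryDirectory, which is removed when the object is cleaned up, so the on-disc database cannot outlive the run. A small sketch of that lifecycle (the file name is a stand-in):

from pathlib import Path
from tempfile import TemporaryDirectory

temp_dir = TemporaryDirectory()
working_dir = Path(temp_dir.name)
(working_dir / 'shelf.db').write_bytes(b'')   # stand-in for the real shelf file

temp_dir.cleanup()            # also runs implicitly when the object is finalized
print(working_dir.exists())   # False: the directory and its contents are gone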