Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b8e2a9c5
Commit
b8e2a9c5
authored
Apr 22, 2019
by
Matthew Carrigan
Browse files
Made --reduce_memory actually do something in finetune_on_pregenerated
parent
af8a0384
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
examples/lm_finetuning/finetune_on_pregenerated.py
examples/lm_finetuning/finetune_on_pregenerated.py
+2
-2
No files found.
examples/lm_finetuning/finetune_on_pregenerated.py
View file @
b8e2a9c5
...
@@ -74,7 +74,7 @@ class PregeneratedDataset(Dataset):
...
@@ -74,7 +74,7 @@ class PregeneratedDataset(Dataset):
mode
=
'w+'
,
dtype
=
np
.
int32
,
shape
=
(
num_samples
,
seq_len
))
mode
=
'w+'
,
dtype
=
np
.
int32
,
shape
=
(
num_samples
,
seq_len
))
input_masks
=
np
.
memmap
(
filename
=
self
.
working_dir
/
'input_masks.memmap'
,
input_masks
=
np
.
memmap
(
filename
=
self
.
working_dir
/
'input_masks.memmap'
,
shape
=
(
num_samples
,
seq_len
),
mode
=
'w+'
,
dtype
=
np
.
bool
)
shape
=
(
num_samples
,
seq_len
),
mode
=
'w+'
,
dtype
=
np
.
bool
)
segment_ids
=
np
.
memmap
(
filename
=
self
.
working_dir
/
'
input_mask
s.memmap'
,
segment_ids
=
np
.
memmap
(
filename
=
self
.
working_dir
/
'
segment_id
s.memmap'
,
shape
=
(
num_samples
,
seq_len
),
mode
=
'w+'
,
dtype
=
np
.
bool
)
shape
=
(
num_samples
,
seq_len
),
mode
=
'w+'
,
dtype
=
np
.
bool
)
lm_label_ids
=
np
.
memmap
(
filename
=
self
.
working_dir
/
'lm_label_ids.memmap'
,
lm_label_ids
=
np
.
memmap
(
filename
=
self
.
working_dir
/
'lm_label_ids.memmap'
,
shape
=
(
num_samples
,
seq_len
),
mode
=
'w+'
,
dtype
=
np
.
int32
)
shape
=
(
num_samples
,
seq_len
),
mode
=
'w+'
,
dtype
=
np
.
int32
)
...
@@ -283,7 +283,7 @@ def main():
...
@@ -283,7 +283,7 @@ def main():
model
.
train
()
model
.
train
()
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
epoch_dataset
=
PregeneratedDataset
(
epoch
=
epoch
,
training_path
=
args
.
pregenerated_data
,
tokenizer
=
tokenizer
,
epoch_dataset
=
PregeneratedDataset
(
epoch
=
epoch
,
training_path
=
args
.
pregenerated_data
,
tokenizer
=
tokenizer
,
num_data_epochs
=
num_data_epochs
)
num_data_epochs
=
num_data_epochs
,
reduce_memory
=
args
.
reduce_memory
)
if
args
.
local_rank
==
-
1
:
if
args
.
local_rank
==
-
1
:
train_sampler
=
RandomSampler
(
epoch_dataset
)
train_sampler
=
RandomSampler
(
epoch_dataset
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment