Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b3caec5a
Commit
b3caec5a
authored
Dec 09, 2018
by
thomwolf
Browse files
adding save checkpoint and loading in examples
parent
85fff78c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
2 deletions
+6
-2
examples/run_classifier.py
examples/run_classifier.py
+5
-1
examples/run_squad.py
examples/run_squad.py
+1
-1
No files found.
examples/run_classifier.py
View file @
b3caec5a
...
@@ -329,7 +329,7 @@ def main():
...
@@ -329,7 +329,7 @@ def main():
default
=
None
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"The output directory where the model checkpoints will be written."
)
help
=
"The output directory where the model
predictions and
checkpoints will be written."
)
## Other parameters
## Other parameters
parser
.
add_argument
(
"--max_seq_length"
,
parser
.
add_argument
(
"--max_seq_length"
,
...
@@ -593,6 +593,10 @@ def main():
...
@@ -593,6 +593,10 @@ def main():
'global_step'
:
global_step
,
'global_step'
:
global_step
,
'loss'
:
tr_loss
/
nb_tr_steps
}
'loss'
:
tr_loss
/
nb_tr_steps
}
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
raise
NotImplementedError
# TODO add save of the configuration file and vocabulary file also ?
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
torch
.
save
(
model_to_save
,
output_model_file
)
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
with
open
(
output_eval_file
,
"w"
)
as
writer
:
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results *****"
)
logger
.
info
(
"***** Eval results *****"
)
...
...
examples/run_squad.py
View file @
b3caec5a
...
@@ -690,7 +690,7 @@ def main():
...
@@ -690,7 +690,7 @@ def main():
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The output directory where the model checkpoints will be written."
)
help
=
"The output directory where the model checkpoints
and predictions
will be written."
)
## Other parameters
## Other parameters
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"SQuAD json for training. E.g., train-v1.1.json"
)
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"SQuAD json for training. E.g., train-v1.1.json"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment