Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
174cdbcc
Commit
174cdbcc
authored
Dec 09, 2018
by
thomwolf
Browse files
adding save checkpoint and loading in examples
parent
1db916b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
2 deletions
+6
-2
examples/run_classifier.py
examples/run_classifier.py
+5
-1
examples/run_squad.py
examples/run_squad.py
+1
-1
No files found.
examples/run_classifier.py
View file @
174cdbcc
...
@@ -359,7 +359,7 @@ def main():
...
@@ -359,7 +359,7 @@ def main():
default
=
None
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"The output directory where the model checkpoints will be written."
)
help
=
"The output directory where the model
predictions and
checkpoints will be written."
)
## Other parameters
## Other parameters
parser
.
add_argument
(
"--max_seq_length"
,
parser
.
add_argument
(
"--max_seq_length"
,
...
@@ -626,6 +626,10 @@ def main():
...
@@ -626,6 +626,10 @@ def main():
'global_step'
:
global_step
,
'global_step'
:
global_step
,
'loss'
:
tr_loss
/
nb_tr_steps
}
'loss'
:
tr_loss
/
nb_tr_steps
}
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
raise
NotImplementedError
# TODO add save of the configuration file and vocabulary file also ?
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
torch
.
save
(
model_to_save
,
output_model_file
)
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
with
open
(
output_eval_file
,
"w"
)
as
writer
:
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results *****"
)
logger
.
info
(
"***** Eval results *****"
)
...
...
examples/run_squad.py
View file @
174cdbcc
...
@@ -706,7 +706,7 @@ def main():
...
@@ -706,7 +706,7 @@ def main():
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The output directory where the model checkpoints will be written."
)
help
=
"The output directory where the model checkpoints
and predictions
will be written."
)
## Other parameters
## Other parameters
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"SQuAD json for training. E.g., train-v1.1.json"
)
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"SQuAD json for training. E.g., train-v1.1.json"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment