Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d3fcec1a
Commit
d3fcec1a
authored
Dec 11, 2018
by
thomwolf
Browse files
add saving and loading model in examples
parent
93f335ef
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
4 deletions
+18
-4
examples/run_classifier.py
examples/run_classifier.py
+9
-4
examples/run_squad.py
examples/run_squad.py
+9
-0
No files found.
examples/run_classifier.py
View file @
d3fcec1a
...
@@ -546,6 +546,15 @@ def main():
...
@@ -546,6 +546,15 @@ def main():
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
global_step
+=
1
global_step
+=
1
# Save a trained model
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
# Load a trained model that you have fine-tuned
model_state_dict
=
torch
.
load
(
output_model_file
)
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
bert_model
,
state_dict
=
model_state_dict
)
if
args
.
do_eval
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
if
args
.
do_eval
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
eval_features
=
convert_examples_to_features
(
eval_features
=
convert_examples_to_features
(
...
@@ -593,10 +602,6 @@ def main():
...
@@ -593,10 +602,6 @@ def main():
'global_step'
:
global_step
,
'global_step'
:
global_step
,
'loss'
:
tr_loss
/
nb_tr_steps
}
'loss'
:
tr_loss
/
nb_tr_steps
}
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
raise
NotImplementedError
# TODO add save of the configuration file and vocabulary file also ?
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
torch
.
save
(
model_to_save
,
output_model_file
)
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
with
open
(
output_eval_file
,
"w"
)
as
writer
:
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results *****"
)
logger
.
info
(
"***** Eval results *****"
)
...
...
examples/run_squad.py
View file @
d3fcec1a
...
@@ -911,6 +911,15 @@ def main():
...
@@ -911,6 +911,15 @@ def main():
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
global_step
+=
1
global_step
+=
1
# Save a trained model
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
# Load a trained model that you have fine-tuned
model_state_dict
=
torch
.
load
(
output_model_file
)
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
state_dict
=
model_state_dict
)
if
args
.
do_predict
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
if
args
.
do_predict
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
eval_examples
=
read_squad_examples
(
eval_examples
=
read_squad_examples
(
input_file
=
args
.
predict_file
,
is_training
=
False
)
input_file
=
args
.
predict_file
,
is_training
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment