chenpangpang / transformers

Commit eebc8abb authored Feb 11, 2019 by thomwolf

clarify and unify model saving logic in examples

Parent: 81c7e3ec
Showing 4 changed files with 60 additions and 29 deletions:

- README.md (+2, -1)
- examples/run_classifier.py (+30, -12)
- examples/run_squad.py (+12, -7)
- examples/run_swag.py (+16, -9)
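For reference, the pattern this commit standardizes across the example scripts is: after training, save the fine-tuned weights under WEIGHTS_NAME and the model configuration under CONFIG_NAME, then rebuild the model from those two files rather than from a hard-coded "pytorch_model.bin". A minimal sketch of the round trip, assuming pytorch_pretrained_bert is installed; the output directory, the "bert-base-uncased" shortcut, and num_labels=2 are illustrative, not part of the commit:

    import os
    import torch
    from pytorch_pretrained_bert.modeling import (BertConfig, BertForSequenceClassification,
                                                  CONFIG_NAME, WEIGHTS_NAME)

    output_dir = "/tmp/mrpc_output/"  # illustrative, echoes the README example
    num_labels = 2                    # illustrative, MRPC is a two-label task
    os.makedirs(output_dir, exist_ok=True)

    # Save: unwrap DataParallel/DistributedDataParallel so only the bare model is
    # serialized, then write weights and config side by side in output_dir.
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
    model_to_save = model.module if hasattr(model, "module") else model
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
    with open(os.path.join(output_dir, CONFIG_NAME), "w") as f:
        f.write(model_to_save.config.to_json_string())

    # Load: rebuild the architecture from the saved config, then restore the weights.
    config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
    model = BertForSequenceClassification(config, num_labels=num_labels)
    model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))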
README.md

...
@@ -779,7 +779,8 @@ python run_classifier.py \
   --train_batch_size 32 \
   --learning_rate 2e-5 \
   --num_train_epochs 3.0 \
-  --output_dir /tmp/mrpc_output/
+  --output_dir /tmp/mrpc_output/ \
+  --fp16
 ```

 #### SQuAD
...
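Note on the fix itself: the trailing backslash added after --output_dir /tmp/mrpc_output/ is what makes the documented command valid; without a line continuation there, the shell would end the invocation at that flag and --fp16 on the next line would never reach the script.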
examples/run_classifier.py

...
@@ -23,7 +23,6 @@ import logging
 import os
 import random
-import sys
 from io import open

 import numpy as np
 import torch
...
@@ -33,7 +32,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
...
@@ -92,7 +91,7 @@ class DataProcessor(object):
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
-        with open(input_file, "rb") as f:
+        with open(input_file, "r") as f:
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             lines = []
             for line in reader:
...
@@ -324,6 +323,10 @@ def main():
                         help="The output directory where the model predictions and checkpoints will be written.")

     ## Other parameters
+    parser.add_argument("--cache_dir",
+                        default="",
+                        type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length",
                         default=128,
                         type=int,
...
@@ -383,9 +386,17 @@
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
+    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
     args = parser.parse_args()

+    if args.server_ip and args.server_port:
+        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+        import ptvsd
+        print("Waiting for debugger attach")
+        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+        ptvsd.wait_for_attach()
+
     processors = {
         "cola": ColaProcessor,
         "mnli": MnliProcessor,
...
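Usage note (an illustration, not part of the diff): launching with both flags set, e.g. python run_classifier.py --server_ip 0.0.0.0 --server_port 5678 plus the usual task arguments, makes the script print "Waiting for debugger attach" and block in ptvsd.wait_for_attach() until a remote debugger such as VS Code's Python attach configuration connects to that port; leaving either flag at its empty default skips the block entirely.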
@@ -451,8 +462,9 @@ def main():
         num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

     # Prepare model
+    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
     model = BertForSequenceClassification.from_pretrained(args.bert_model,
-              cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
+              cache_dir=cache_dir,
               num_labels=num_labels)
     if args.fp16:
         model.half()
...
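In other words, the new --cache_dir flag lets the user choose where the pretrained archive is downloaded and unpacked, while the empty default preserves the previous behavior of a per-rank 'distributed_{}'-formatted subdirectory under PYTORCH_PRETRAINED_BERT_CACHE, which keeps distributed workers from clobbering each other's caches.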
@@ -549,15 +561,21 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForSequenceClassification(config, num_labels=num_labels)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)

     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
...
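Behavioral note: before this change the script always wrote to a hard-coded "pytorch_model.bin" and reloaded the fine-tuned weights through from_pretrained(args.bert_model, state_dict=...), so the original model shortcut had to be re-resolved just to rebuild the architecture. Afterwards, a training run leaves behind both WEIGHTS_NAME and CONFIG_NAME, evaluation reconstructs the model purely from those two files via BertConfig plus load_state_dict, and an eval-only invocation (--do_eval without --do_train) falls back to the pretrained --bert_model weights instead of expecting a checkpoint in --output_dir.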
examples/run_squad.py

...
@@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,
...
@@ -1001,14 +1001,19 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForQuestionAnswering(config)
+        model.load_state_dict(torch.load(output_model_file))
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
...
examples/run_swag.py

...
@@ -469,18 +469,25 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model
     if args.do_train:
+        # Save a trained model and the associated configuration
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = BertForMultipleChoice.from_pretrained(args.bert_model, state_dict=model_state_dict, num_choices=4)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForMultipleChoice(config, num_choices=4)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)

     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
         eval_features = convert_examples_to_features(
...