chenpangpang / transformers · Commits

Commit eebc8abb, authored Feb 11, 2019 by thomwolf

clarify and unify model saving logic in examples

Parent: 81c7e3ec
Showing 4 changed files with 60 additions and 29 deletions (+60 / −29):
- README.md (+2 / −1)
- examples/run_classifier.py (+30 / −12)
- examples/run_squad.py (+12 / −7)
- examples/run_swag.py (+16 / −9)

README.md (view file @ eebc8abb)

````diff
@@ -779,7 +779,8 @@ python run_classifier.py \
   --train_batch_size 32 \
   --learning_rate 2e-5 \
   --num_train_epochs 3.0 \
-  --output_dir /tmp/mrpc_output/
+  --output_dir /tmp/mrpc_output/ \
+  --fp16
 ```
 
 #### SQuAD
````

examples/run_classifier.py (view file @ eebc8abb)

```diff
@@ -23,7 +23,6 @@ import logging
 import os
 import random
 import sys
-from io import open
 
 import numpy as np
 import torch
@@ -33,7 +32,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
@@ -92,7 +91,7 @@ class DataProcessor(object):
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
-        with open(input_file, "rb") as f:
+        with open(input_file, "r") as f:
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             lines = []
             for line in reader:
@@ -324,6 +323,10 @@ def main():
                         help="The output directory where the model predictions and checkpoints will be written.")
 
     ## Other parameters
+    parser.add_argument("--cache_dir",
+                        default="",
+                        type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length",
                         default=128,
                         type=int,
@@ -383,9 +386,17 @@ def main():
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
+    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
     args = parser.parse_args()
 
+    if args.server_ip and args.server_port:
+        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+        import ptvsd
+        print("Waiting for debugger attach")
+        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+        ptvsd.wait_for_attach()
+
     processors = {
         "cola": ColaProcessor,
         "mnli": MnliProcessor,
@@ -451,8 +462,9 @@ def main():
         num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
 
     # Prepare model
+    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
     model = BertForSequenceClassification.from_pretrained(args.bert_model,
-              cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
+              cache_dir=cache_dir,
               num_labels = num_labels)
     if args.fp16:
         model.half()
@@ -549,15 +561,21 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
 
-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
 
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForSequenceClassification(config, num_labels=num_labels)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)
 
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
```
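
Read outside the diff context, the "save" half of the pattern above condenses to a small routine. The sketch below is illustrative only: `save_finetuned` and its arguments are hypothetical names, and it simply restates the lines added to run_classifier.py (unwrap a possible DataParallel/DistributedDataParallel wrapper, write the weights under WEIGHTS_NAME, write the configuration JSON under CONFIG_NAME).

```python
import os
import torch
from pytorch_pretrained_bert.modeling import WEIGHTS_NAME, CONFIG_NAME

def save_finetuned(model, output_dir):
    """Condensed version of the saving logic added in this commit (illustrative sketch)."""
    # Only save the model itself, not a DataParallel/DistributedDataParallel wrapper
    model_to_save = model.module if hasattr(model, 'module') else model
    # Weights go under the library-wide filename constant
    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
    torch.save(model_to_save.state_dict(), output_model_file)
    # The configuration is serialized to JSON next to the weights
    output_config_file = os.path.join(output_dir, CONFIG_NAME)
    with open(output_config_file, 'w') as f:
        f.write(model_to_save.config.to_json_string())
```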

examples/run_squad.py (view file @ eebc8abb)

```diff
@@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,
@@ -1001,14 +1001,19 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
 
-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
 
-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForQuestionAnswering(config)
+        model.load_state_dict(torch.load(output_model_file))
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
```
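
The "load" half is the mirror image: rebuild the configuration from the saved JSON, instantiate the task-specific head from it, and load the weight file. Again a hypothetical helper that only condenses the lines added to run_squad.py; `output_dir` is assumed to be the directory a training run wrote to.

```python
import os
import torch
from pytorch_pretrained_bert.modeling import (BertForQuestionAnswering, BertConfig,
                                              WEIGHTS_NAME, CONFIG_NAME)

def load_finetuned_qa(output_dir):
    """Rebuild a fine-tuned QA model from the saved config and weights (illustrative sketch)."""
    # Read the configuration that was written as JSON at save time
    config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
    # Instantiate the architecture from the config, then load the fine-tuned state dict
    model = BertForQuestionAnswering(config)
    model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))
    return model
```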

examples/run_swag.py (view file @ eebc8abb)

```diff
@@ -469,18 +469,25 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
 
-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
-
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForMultipleChoice.from_pretrained(args.bert_model,
-        state_dict=model_state_dict,
-        num_choices=4)
+    if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForMultipleChoice(config, num_choices=4)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)
 
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
         eval_features = convert_examples_to_features(
```
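
run_swag.py follows the same convention; the only difference is that head-specific constructor arguments (num_choices=4 for SWAG) are passed alongside the config when rebuilding the model. A minimal usage sketch, assuming a hypothetical /tmp/swag_output directory produced by a run_swag.py training run:

```python
import os
import torch
from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig,
                                              WEIGHTS_NAME, CONFIG_NAME)

output_dir = "/tmp/swag_output"  # hypothetical directory written by run_swag.py --do_train

# Rebuild the model the same way the updated example does
config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
model = BertForMultipleChoice(config, num_choices=4)  # head-specific argument passed with the config
model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))
model.eval()  # switch to evaluation mode before running inference
```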