Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f93229b9
Commit
f93229b9
authored
Aug 26, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 26, 2019
Browse files
Internal change
PiperOrigin-RevId: 265510206
parent
1e48a60a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
8 deletions
+38
-8
official/bert/common_flags.py
official/bert/common_flags.py
+4
-0
official/bert/run_classifier.py
official/bert/run_classifier.py
+0
-4
official/bert/run_squad.py
official/bert/run_squad.py
+34
-4
No files found.
official/bert/common_flags.py
View file @
f93229b9
...
...
@@ -27,6 +27,10 @@ def define_common_bert_flags():
flags
.
DEFINE_string
(
'model_dir'
,
None
,
(
'The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.'
))
flags
.
DEFINE_string
(
'model_export_path'
,
None
,
'Path to the directory, where trainined model will be '
'exported.'
)
flags
.
DEFINE_string
(
'tpu'
,
''
,
'TPU address to connect to.'
)
flags
.
DEFINE_string
(
'init_checkpoint'
,
None
,
...
...
official/bert/run_classifier.py
View file @
f93229b9
...
...
@@ -48,10 +48,6 @@ flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.'
)
flags
.
DEFINE_string
(
'eval_data_path'
,
None
,
'Path to evaluation data for BERT classifier.'
)
flags
.
DEFINE_string
(
'model_export_path'
,
None
,
'Path to the directory, where trainined model will be '
'exported.'
)
# Model training specific flags.
flags
.
DEFINE_string
(
'input_meta_data_path'
,
None
,
...
...
official/bert/run_squad.py
View file @
f93229b9
...
...
@@ -31,6 +31,7 @@ import tensorflow as tf
from
official.bert
import
bert_models
from
official.bert
import
common_flags
from
official.bert
import
input_pipeline
from
official.bert
import
model_saving_utils
from
official.bert
import
model_training_utils
from
official.bert
import
modeling
from
official.bert
import
optimization
...
...
@@ -39,8 +40,13 @@ from official.bert import tokenization
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
tpu_lib
flags
.
DEFINE_bool
(
'do_train'
,
False
,
'Whether to run training.'
)
flags
.
DEFINE_bool
(
'do_predict'
,
False
,
'Whether to run eval on the dev set.'
)
flags
.
DEFINE_enum
(
'mode'
,
'train'
,
[
'train'
,
'predict'
,
'export_only'
],
'One of {"train", "predict", "export_only"}. `train`: '
'trains the model and evaluates in the meantime. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.'
)
flags
.
DEFINE_string
(
'train_data_path'
,
''
,
'Training data path with train tfrecords.'
)
flags
.
DEFINE_string
(
...
...
@@ -311,6 +317,26 @@ def predict_squad(strategy, input_meta_data):
verbose
=
FLAGS
.
verbose_logging
)
def
export_squad
(
model_export_path
,
input_meta_data
):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
if
not
model_export_path
:
raise
ValueError
(
'Export path is not specified: %s'
%
model_export_path
)
bert_config
=
modeling
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
input_meta_data
[
'max_seq_length'
],
float_type
=
tf
.
float32
)
model_saving_utils
.
export_bert_model
(
model_export_path
,
model
=
squad_model
,
checkpoint_dir
=
FLAGS
.
model_dir
)
def
main
(
_
):
# Users should always run this script under TF 2.x
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
...
...
@@ -318,6 +344,10 @@ def main(_):
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
if
FLAGS
.
mode
==
'export_only'
:
export_squad
(
FLAGS
.
model_export_path
,
input_meta_data
)
return
strategy
=
None
if
FLAGS
.
strategy_type
==
'mirror'
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
...
...
@@ -330,9 +360,9 @@ def main(_):
else
:
raise
ValueError
(
'The distribution strategy type is not supported: %s'
%
FLAGS
.
strategy_type
)
if
FLAGS
.
do_
train
:
if
FLAGS
.
mode
==
'
train
'
:
train_squad
(
strategy
,
input_meta_data
)
if
FLAGS
.
do_
predict
:
if
FLAGS
.
mode
==
'
predict
'
:
predict_squad
(
strategy
,
input_meta_data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment