"llm/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1ae0750a21d206227a1990af74cbc68912d63ea5"
Unverified Commit db778817 authored by Katherine Wu's avatar Katherine Wu Committed by GitHub
Browse files

Add export savedmodel to wide_deep (#4041)

parent be7da421
...@@ -246,7 +246,6 @@ class MNISTArgParser(argparse.ArgumentParser): ...@@ -246,7 +246,6 @@ class MNISTArgParser(argparse.ArgumentParser):
super(MNISTArgParser, self).__init__(parents=[ super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(),
]) ])
self.set_defaults( self.set_defaults(
......
...@@ -465,7 +465,6 @@ class ResnetArgParser(argparse.ArgumentParser): ...@@ -465,7 +465,6 @@ class ResnetArgParser(argparse.ArgumentParser):
parsers.BaseParser(), parsers.BaseParser(),
parsers.PerformanceParser(), parsers.PerformanceParser(),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(),
parsers.BenchmarkParser(), parsers.BenchmarkParser(),
]) ])
......
...@@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser): ...@@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser):
batch_size: Create a flag to specify the batch size. batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported.
""" """
def __init__(self, add_help=False, data_dir=True, model_dir=True, def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True, train_epochs=True, epochs_between_evals=True,
stop_threshold=True, batch_size=True, multi_gpu=True, stop_threshold=True, batch_size=True, multi_gpu=True,
hooks=True): hooks=True, export_dir=True):
super(BaseParser, self).__init__(add_help=add_help) super(BaseParser, self).__init__(add_help=add_help)
if data_dir: if data_dir:
...@@ -176,6 +177,15 @@ class BaseParser(argparse.ArgumentParser): ...@@ -176,6 +177,15 @@ class BaseParser(argparse.ArgumentParser):
metavar="<HK>" metavar="<HK>"
) )
if export_dir:
self.add_argument(
"--export_dir", "-ed",
help="[default: %(default)s] If set, a SavedModel serialization of "
"the model will be exported to this directory at the end of "
"training. See the README for more details and relevant links.",
metavar="<ED>"
)
class PerformanceParser(argparse.ArgumentParser): class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments. """Default parser for specifying performance tuning arguments.
...@@ -292,29 +302,6 @@ class ImageModelParser(argparse.ArgumentParser): ...@@ -292,29 +302,6 @@ class ImageModelParser(argparse.ArgumentParser):
) )
class ExportParser(argparse.ArgumentParser):
  """Argument parser covering SavedModel / graph-def export options.

  Kept as a standalone parser for now; it should be folded into BaseParser
  once every model supports export.

  Args:
    add_help: Whether to expose the "--help" flag. Leave False when this
      instance is used as a parent parser.
    export_dir: Whether to register the --export_dir flag.
  """

  def __init__(self, add_help=False, export_dir=True):
    super(ExportParser, self).__init__(add_help=add_help)
    # Guard clause: nothing to register when the flag is disabled.
    if not export_dir:
      return
    self.add_argument(
        "--export_dir", "-ed",
        help="[default: %(default)s] If set, a SavedModel serialization of "
             "the model will be exported to this directory at the end of "
             "training. See the README for more details and relevant links.",
        metavar="<ED>"
    )
class BenchmarkParser(argparse.ArgumentParser): class BenchmarkParser(argparse.ArgumentParser):
"""Default parser for benchmark logging. """Default parser for benchmark logging.
......
...@@ -47,6 +47,37 @@ Run TensorBoard to inspect the details about the graph and training progression. ...@@ -47,6 +47,37 @@ Run TensorBoard to inspect the details about the graph and training progression.
tensorboard --logdir=/tmp/census_model tensorboard --logdir=/tmp/census_model
``` ```
## Inference with SavedModel
You can export the model into TensorFlow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:
```
python wide_deep.py --export_dir /tmp/wide_deep_saved_model
```
After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
Try the following commands to inspect the SavedModel:
**Replace `${TIMESTAMP}` with the name of the timestamped folder produced by the export (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```
### Inference
Let's use the model to predict the income group of two examples:
```
saved_model_cli run --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_examples='examples=[{"age":[46.], "education_num":[10.], "capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.], "capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'
```
This will print out the predicted classes and class probabilities. Class 0 is the <=50k group and 1 is the >50k group.
## Additional Links ## Additional Links
If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed). If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
......
...@@ -175,6 +175,27 @@ def input_fn(data_file, num_epochs, shuffle, batch_size): ...@@ -175,6 +175,27 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
return dataset return dataset
def export_model(model, model_type, export_dir):
  """Serialize the trained estimator to SavedModel format.

  Args:
    model: Estimator object to export.
    model_type: string indicating model type; "wide", "deep" or "wide_deep".
      Selects which feature columns drive the serving input receiver.
    export_dir: directory the SavedModel is written to.
  """
  wide_cols, deep_cols = build_model_columns()
  # Pick the column set matching the model type; anything other than the
  # two pure variants gets the combined wide+deep columns.
  selection = {'wide': wide_cols, 'deep': deep_cols}
  columns = selection.get(model_type, wide_cols + deep_cols)
  spec = tf.feature_column.make_parse_example_spec(columns)
  receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      spec)
  model.export_savedmodel(export_dir, receiver_fn)
def main(argv): def main(argv):
parser = WideDeepArgParser() parser = WideDeepArgParser()
flags = parser.parse_args(args=argv[1:]) flags = parser.parse_args(args=argv[1:])
...@@ -216,6 +237,10 @@ def main(argv): ...@@ -216,6 +237,10 @@ def main(argv):
flags.stop_threshold, results['accuracy']): flags.stop_threshold, results['accuracy']):
break break
# Export the model
if flags.export_dir is not None:
export_model(model, flags.model_type, flags.export_dir)
class WideDeepArgParser(argparse.ArgumentParser): class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model.""" """Argument parser for running the wide deep model."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment