"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f2756253e6874a5af8d22ec37462d1ce75d99c94"
Commit e888406e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 409010753
parent ab1ffea0
...@@ -110,6 +110,7 @@ def get_export_config_from_flags(): ...@@ -110,6 +110,7 @@ def get_export_config_from_flags():
dataset_split=FLAGS.dataset_split) dataset_split=FLAGS.dataset_split)
export_config = export_util.ExportConfig( export_config = export_util.ExportConfig(
model_name=FLAGS.model_name, model_name=FLAGS.model_name,
output_layer=FLAGS.output_layer,
ckpt_path=FLAGS.ckpt_path, ckpt_path=FLAGS.ckpt_path,
ckpt_format=FLAGS.ckpt_format, ckpt_format=FLAGS.ckpt_format,
output_dir=FLAGS.output_dir, output_dir=FLAGS.output_dir,
......
...@@ -69,7 +69,9 @@ class ExportConfig(base_config.Config): ...@@ -69,7 +69,9 @@ class ExportConfig(base_config.Config):
"""Configuration for exporting models as tflite and saved_models. """Configuration for exporting models as tflite and saved_models.
Attributes: Attributes:
model_name: One of the registered model names model_name: One of the registered model names.
output_layer: Layer name to take the output from. Can be used to take the
output from an intermediate layer.
ckpt_path: Path of the training checkpoint. If not provided tflite with ckpt_path: Path of the training checkpoint. If not provided tflite with
random parameters is exported. random parameters is exported.
ckpt_format: Format of the checkpoint. tf_checkpoint is for ckpt files from ckpt_format: Format of the checkpoint. tf_checkpoint is for ckpt files from
...@@ -92,7 +94,8 @@ class ExportConfig(base_config.Config): ...@@ -92,7 +94,8 @@ class ExportConfig(base_config.Config):
resize bilinear to 128x128, then argmax then resize nn to 512x512 resize bilinear to 128x128, then argmax then resize nn to 512x512
""" """
quantization_config: QuantizationConfig = QuantizationConfig() quantization_config: QuantizationConfig = QuantizationConfig()
model_name: str = None model_name: Optional[str] = None
output_layer: Optional[str] = None
ckpt_path: Optional[str] = None ckpt_path: Optional[str] = None
ckpt_format: Optional[str] = 'tf_checkpoint' ckpt_format: Optional[str] = 'tf_checkpoint'
output_dir: str = '/tmp/' output_dir: str = '/tmp/'
......
...@@ -112,7 +112,6 @@ class EdgeTPUTask(base_task.Task): ...@@ -112,7 +112,6 @@ class EdgeTPUTask(base_task.Task):
else: else:
raise ValueError('Model has to be mobilenet-edgetpu model or searched' raise ValueError('Model has to be mobilenet-edgetpu model or searched'
'model with given saved model path.') 'model with given saved model path.')
model.summary()
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment