Commit aa75f089 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 406303006
parent 2fec36b2
......@@ -20,11 +20,35 @@ import os
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.core import task_factory
from official.projects.edgetpu.vision.serving import export_util
def _build_experiment_model(experiment_type):
"""Builds model from experiment type configuration w/o loading checkpoint.
To reduce test latency and avoid unexpected errors (e.g. checkpoint files not
exist in the dedicated path), we skip the checkpoint loading for the tests.
Args:
experiment_type: model type for the experiment.
Returns:
TF/Keras model for the task.
"""
params = exp_factory.get_exp_config(experiment_type)
if 'deeplabv3plus_mobilenet_edgetpuv2' in experiment_type:
params.task.model.backbone.mobilenet_edgetpu.pretrained_checkpoint_path = None
if 'autoseg_edgetpu' in experiment_type:
params.task.model.model_params.model_weights_path = None
params.validate()
params.lock()
task = task_factory.get_task(params.task)
return task.build_model()
def _build_model(config):
model = export_util.build_experiment_model(config.model_name)
model = _build_experiment_model(config.model_name)
model_input = tf.keras.Input(
shape=(config.image_size, config.image_size, 3), batch_size=1)
model_output = export_util.finalize_serving(model(model_input), config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment