Commit ab9cb561 authored by Jinoo Baek's avatar Jinoo Baek Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 468296610
parent 5d6212c8
......@@ -136,21 +136,30 @@ def main(_):
if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top
from cloud_tpu.inference_converter import converter_cli
from cloud_tpu.inference_converter import converter_options_pb2
from cloud_tpu.inference_converter_v2 import converter_options_v2_pb2
from cloud_tpu.inference_converter_v2.python import converter
tpu_dir = os.path.join(export_dir, "tpu")
options = converter_options_pb2.ConverterOptions()
batch_options = []
if FLAGS.allowed_batch_size is not None:
allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
options.batch_options.num_batch_threads = FLAGS.num_batch_threads
options.batch_options.max_batch_size = allowed_batch_sizes[-1]
options.batch_options.batch_timeout_micros = FLAGS.batch_timeout_micros
options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes
options.batch_options.max_enqueued_batches = FLAGS.max_enqueued_batches
converter_cli.ConvertSavedModel(
export_dir, tpu_dir, function_alias="tpu_candidate", options=options,
graph_rewrite_only=True)
batch_option = converter_options_v2_pb2.BatchOptionsV2(
num_batch_threads=FLAGS.num_batch_threads,
max_batch_size=allowed_batch_sizes[-1],
batch_timeout_micros=FLAGS.batch_timeout_micros,
allowed_batch_sizes=allowed_batch_sizes,
max_enqueued_batches=FLAGS.max_enqueued_batches
)
batch_options.append(batch_option)
converter_options = converter_options_v2_pb2.ConverterOptionsV2(
tpu_functions=[
converter_options_v2_pb2.TpuFunction(function_alias="tpu_candidate")
],
batch_options=batch_options,
)
converter.ConvertSavedModel(export_dir, tpu_dir, converter_options)
if __name__ == "__main__":
define_flags()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment