"...test_cli/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "1b71bb9309de2857bc152b94138a2296c7df0e68"
Commit b81fe53a authored by Jinoo Baek's avatar Jinoo Baek Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 468296610
parent ceacd37c
...@@ -136,21 +136,30 @@ def main(_): ...@@ -136,21 +136,30 @@ def main(_):
if FLAGS.convert_tpu: if FLAGS.convert_tpu:
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
from cloud_tpu.inference_converter import converter_cli from cloud_tpu.inference_converter_v2 import converter_options_v2_pb2
from cloud_tpu.inference_converter import converter_options_pb2 from cloud_tpu.inference_converter_v2.python import converter
tpu_dir = os.path.join(export_dir, "tpu") tpu_dir = os.path.join(export_dir, "tpu")
options = converter_options_pb2.ConverterOptions() batch_options = []
if FLAGS.allowed_batch_size is not None: if FLAGS.allowed_batch_size is not None:
allowed_batch_sizes = sorted(FLAGS.allowed_batch_size) allowed_batch_sizes = sorted(FLAGS.allowed_batch_size)
options.batch_options.num_batch_threads = FLAGS.num_batch_threads batch_option = converter_options_v2_pb2.BatchOptionsV2(
options.batch_options.max_batch_size = allowed_batch_sizes[-1] num_batch_threads=FLAGS.num_batch_threads,
options.batch_options.batch_timeout_micros = FLAGS.batch_timeout_micros max_batch_size=allowed_batch_sizes[-1],
options.batch_options.allowed_batch_sizes[:] = allowed_batch_sizes batch_timeout_micros=FLAGS.batch_timeout_micros,
options.batch_options.max_enqueued_batches = FLAGS.max_enqueued_batches allowed_batch_sizes=allowed_batch_sizes,
converter_cli.ConvertSavedModel( max_enqueued_batches=FLAGS.max_enqueued_batches
export_dir, tpu_dir, function_alias="tpu_candidate", options=options, )
graph_rewrite_only=True) batch_options.append(batch_option)
converter_options = converter_options_v2_pb2.ConverterOptionsV2(
tpu_functions=[
converter_options_v2_pb2.TpuFunction(function_alias="tpu_candidate")
],
batch_options=batch_options,
)
converter.ConvertSavedModel(export_dir, tpu_dir, converter_options)
if __name__ == "__main__": if __name__ == "__main__":
define_flags() define_flags()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment