Commit d11d9845 authored by tjakob, committed by Guangda Lai

Use new tensorrt API (#6828)

parent 30d14a96
@@ -31,7 +31,7 @@ import time
 import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.saved_model.python.saved_model import reader
-import tensorflow.contrib.tensorrt as trt
+from tensorflow.python.compiler.tensorrt import trt_convert as trt
 from official.resnet import imagenet_preprocessing  # pylint: disable=g-bad-import-order
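The import swap above moves the benchmark off the old `tensorflow.contrib.tensorrt` entry point and onto the `trt_convert` module, whose `TrtGraphConverter` class replaces the one-shot `create_inference_graph` call. As a rough sketch of what the converter-based flow looks like on its own (the file name, output node, and sizes below are placeholders, not values taken from this script):

```python
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Load a frozen model; "resnet_frozen.pb" and "softmax_tensor" are placeholder names.
frozen_graph_def = tf.GraphDef()
with tf.gfile.GFile("resnet_frozen.pb", "rb") as f:
  frozen_graph_def.ParseFromString(f.read())

converter = trt.TrtGraphConverter(
    input_graph_def=frozen_graph_def,       # frozen GraphDef to optimize
    nodes_blacklist=["softmax_tensor"],      # output nodes TensorRT must leave intact
    max_batch_size=128,                      # largest batch the built engines will serve
    max_workspace_size_bytes=2 << 30,        # scratch memory for engine building (~2 GB)
    precision_mode="FP16")                   # FP32 / FP16 / INT8
trt_graph_def = converter.convert()          # returns the TensorRT-optimized GraphDef
```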
@@ -212,38 +212,55 @@ def get_frozen_graph(graph_file):
 def get_tftrt_name(graph_name, precision_string):
   return "tftrt_{}_{}".format(precision_string.lower(), graph_name)


-def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
-                  output_node, batch_size=128, workspace_size=2<<10):
-  """Create and save inference graph using the TensorRT library.
+def get_trt_converter(graph_def, precision_mode, output_node, batch_size=128,
+                      workspace_size=2<<10):
+  """Create a TrtGraphConverter object to use later.

   Args:
-    graph_name: string, name of the graph to be used for saving.
     graph_def: GraphDef, the Frozen Graph to be converted.
     precision_mode: string, the precision that TensorRT should convert into.
       Options- FP32, FP16, INT8.
-    output_dir: string, the path to where files should be written.
     output_node: string, the names of the output node that will
       be returned during inference.
     batch_size: int, the number of examples that will be predicted at a time.
     workspace_size: int, size in megabytes that can be used during conversion.

   Returns:
-    GraphDef for the TensorRT inference graph.
+    TrtGraphConverter object.
   """
-  trt_graph = trt.create_inference_graph(
-      graph_def, [output_node], max_batch_size=batch_size,
-      max_workspace_size_bytes=workspace_size<<20,
+  return trt.TrtGraphConverter(
+      input_graph_def=graph_def, nodes_blacklist=[output_node],
+      max_batch_size=batch_size, max_workspace_size_bytes=workspace_size<<20,
       precision_mode=precision_mode)

+
+def get_trt_graph(graph_name, converter, output_dir):
+  """Create and save inference graph using the TensorRT library.
+
+  Args:
+    graph_name: string, name of the graph to be used for saving.
+    converter: TrtGraphConverter object representing the GraphDef to convert.
+    output_dir: string, the path to where files should be written.
+
+  Returns:
+    GraphDef for the TensorRT inference graph.
+  """
+  trt_graph = converter.convert()
   write_graph_to_file(graph_name, trt_graph, output_dir)
   return trt_graph


-def get_trt_graph_from_calib(graph_name, calib_graph_def, output_dir):
+def get_trt_graph_from_calib(graph_name, converter, data, input_node, output_node,
+                             output_dir, num_loops=100):
   """Convert a TensorRT graph used for calibration to an inference graph."""
-  trt_graph = trt.calib_graph_to_infer_graph(calib_graph_def)
+  converter.convert()
+
+  def input_fn():
+    iterator = get_iterator(data)
+    return {input_node: iterator.get_next()}
+
+  trt_graph = converter.calibrate(
+      fetch_names=[output_node],
+      num_runs=num_loops,
+      input_map_fn=input_fn)
   write_graph_to_file(graph_name, trt_graph, output_dir)
   return trt_graph
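For INT8, the new helpers split the work into two explicit steps on the same converter: `convert()` builds the calibration graph, and `calibrate()` runs calibration batches through it and returns the final inference GraphDef. A minimal sketch of that flow, assuming a placeholder `tf.data.Dataset` of calibration batches and placeholder node names:

```python
converter = trt.TrtGraphConverter(
    input_graph_def=frozen_graph_def,        # placeholder frozen GraphDef
    nodes_blacklist=["softmax_tensor"],
    max_batch_size=128,
    precision_mode="INT8")
converter.convert()                           # builds the calibration graph first

def input_fn():
  # Map the graph's input node to batches from a calibration dataset
  # (`calib_dataset` is a placeholder tf.data.Dataset).
  iterator = calib_dataset.make_one_shot_iterator()
  return {"input_tensor": iterator.get_next()}

int8_graph_def = converter.calibrate(
    fetch_names=["softmax_tensor"],           # nodes to fetch while collecting ranges
    num_runs=10,                              # number of calibration batches to run
    input_map_fn=input_fn)
```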
@@ -366,9 +383,9 @@ def run_trt_graph_for_mode(
     graph_name, graph_def, mode, data, log_buffer, flags):
   """Convert, time, and log the graph at `mode` precision using TensorRT."""
   g_name = get_tftrt_name(graph_name, mode)
-  graph = get_trt_graph(
-      g_name, graph_def, mode, flags.output_dir, flags.output_node,
-      flags.batch_size, flags.workspace_size)
+  trt_converter = get_trt_converter(
+      graph_def, mode, flags.output_node, flags.batch_size, flags.workspace_size)
+  graph = get_trt_graph(g_name, trt_converter, flags.output_dir)
   result = time_and_log_graph(g_name, graph, data, log_buffer, flags)
   return result
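After the refactor, the per-precision benchmark path builds a converter first and then materializes the graph from it. Chained directly, it looks roughly like this, with illustrative values standing in for the script's flags:

```python
converter = get_trt_converter(
    frozen_graph_def, "FP16", "softmax_tensor",  # placeholder GraphDef and output node
    batch_size=128, workspace_size=2 << 10)       # workspace given in MB, as in the script
fp16_graph = get_trt_graph("tftrt_fp16_resnet", converter, "/tmp/trt_output")
```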
@@ -476,15 +493,13 @@ def main(argv):
   if flags.int8:
     mode = "INT8"
     print("Running {} graph".format(mode))
-    save_name = get_tftrt_name(graph_name, "INT8_calib")
-    calib_graph = get_trt_graph(
-        save_name, frozen_graph_def, mode, flags.output_dir, flags.output_node,
-        flags.batch_size, flags.workspace_size)
-    time_graph(calib_graph, data, flags.input_node, flags.output_node,
-               num_loops=1)
+    trt_converter = get_trt_converter(
+        frozen_graph_def, mode, flags.output_node, flags.batch_size,
+        flags.workspace_size)
     g_name = get_tftrt_name(graph_name, mode)
-    int8_graph = get_trt_graph_from_calib(g_name, calib_graph, flags.output_dir)
+    int8_graph = get_trt_graph_from_calib(
+        g_name, trt_converter, data, flags.input_node, flags.output_node,
+        flags.output_dir, num_loops=1)
     result = time_and_log_graph(g_name, int8_graph, data, log_buffer, flags)
     results.append((mode, result))
...
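Once `get_trt_graph_from_calib` has written the calibrated graph, it can be loaded and run like any other frozen GraphDef. A hedged sketch of consuming the saved file (the path, file name, and tensor names are placeholders):

```python
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile("/tmp/trt_output/tftrt_int8_resnet.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name="")
  input_t = graph.get_tensor_by_name("input_tensor:0")     # placeholder tensor names
  output_t = graph.get_tensor_by_name("softmax_tensor:0")
  with tf.Session(graph=graph) as sess:
    # `batch` is a NumPy array of preprocessed images supplied by the caller.
    preds = sess.run(output_t, feed_dict={input_t: batch})
```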