Commit d11d9845 authored by tjakob, committed by Guangda Lai
Browse files

Use new tensorrt API (#6828)

parent 30d14a96
......@@ -31,7 +31,7 @@ import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.saved_model.python.saved_model import reader
import tensorflow.contrib.tensorrt as trt
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from official.resnet import imagenet_preprocessing # pylint: disable=g-bad-import-order
......@@ -211,39 +211,56 @@ def get_frozen_graph(graph_file):
def get_tftrt_name(graph_name, precision_string):
return "tftrt_{}_{}".format(precision_string.lower(), graph_name)
def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
output_node, batch_size=128, workspace_size=2<<10):
"""Create and save inference graph using the TensorRT library.
def get_trt_converter(graph_def, precision_mode, output_node, batch_size=128,
workspace_size=2<<10):
""" Create a TrtGraphConverter Object to use later
Args:
graph_name: string, name of the graph to be used for saving.
graph_def: GraphDef, the Frozen Graph to be converted.
precision_mode: string, the precision that TensorRT should convert into.
Options- FP32, FP16, INT8.
output_dir: string, the path to where files should be written.
output_node: string, the names of the output node that will
be returned during inference.
batch_size: int, the number of examples that will be predicted at a time.
workspace_size: int, size in megabytes that can be used during conversion.
Returns:
GraphDef for the TensorRT inference graph.
TrtGraphConverter Object
"""
trt_graph = trt.create_inference_graph(
graph_def, [output_node], max_batch_size=batch_size,
max_workspace_size_bytes=workspace_size<<20,
return trt.TrtGraphConverter(
input_graph_def=graph_def, nodes_blacklist=[output_node],
max_batch_size=batch_size, max_workspace_size_bytes=workspace_size<<20,
precision_mode=precision_mode)
def get_trt_graph(graph_name, converter, output_dir):
  """Run the TensorRT conversion and save the resulting inference graph.

  Args:
    graph_name: string, name of the graph to be used for saving.
    converter: TrtGraphConverter object wrapping the GraphDef to convert.
    output_dir: string, the path to where files should be written.

  Returns:
    GraphDef for the TensorRT inference graph, as returned by
    `converter.convert()`.
  """
  converted_graph_def = converter.convert()
  # Persist the converted graph before handing it back to the caller.
  write_graph_to_file(graph_name, converted_graph_def, output_dir)
  return converted_graph_def
def get_trt_graph_from_calib(graph_name, calib_graph_def, output_dir):
def get_trt_graph_from_calib(graph_name, converter, data, input_node, output_node,
output_dir, num_loops=100):
"""Convert a TensorRT graph used for calibration to an inference graph."""
trt_graph = trt.calib_graph_to_infer_graph(calib_graph_def)
converter.convert()
def input_fn():
iterator = get_iterator(data)
return {input_node: iterator.get_next()}
trt_graph = converter.calibrate(
fetch_names=[output_node],
num_runs=num_loops,
input_map_fn=input_fn)
write_graph_to_file(graph_name, trt_graph, output_dir)
return trt_graph
......@@ -366,9 +383,9 @@ def run_trt_graph_for_mode(
graph_name, graph_def, mode, data, log_buffer, flags):
"""Convert, time, and log the graph at `mode` precision using TensorRT."""
g_name = get_tftrt_name(graph_name, mode)
graph = get_trt_graph(
g_name, graph_def, mode, flags.output_dir, flags.output_node,
flags.batch_size, flags.workspace_size)
trt_converter = get_trt_converter(
graph_def, mode, flags.output_node, flags.batch_size, flags.workspace_size)
graph = get_trt_graph(g_name, trt_converter, flags.output_dir)
result = time_and_log_graph(g_name, graph, data, log_buffer, flags)
return result
......@@ -476,15 +493,13 @@ def main(argv):
if flags.int8:
mode = "INT8"
print("Running {} graph".format(mode))
save_name = get_tftrt_name(graph_name, "INT8_calib")
calib_graph = get_trt_graph(
save_name, frozen_graph_def, mode, flags.output_dir, flags.output_node,
flags.batch_size, flags.workspace_size)
time_graph(calib_graph, data, flags.input_node, flags.output_node,
num_loops=1)
trt_converter = get_trt_converter(
frozen_graph_def, mode, flags.output_node, flags.batch_size,
flags.workspace_size)
g_name = get_tftrt_name(graph_name, mode)
int8_graph = get_trt_graph_from_calib(g_name, calib_graph, flags.output_dir)
int8_graph = get_trt_graph_from_calib(
g_name, trt_converter, data, flags.input_node, flags.output_node,
flags.output_dir, num_loops=1)
result = time_and_log_graph(g_name, int8_graph, data, log_buffer, flags)
results.append((mode, result))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment