Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d11d9845
Commit
d11d9845
authored
May 21, 2019
by
tjakob
Committed by
Guangda Lai
May 21, 2019
Browse files
Use new tensorrt API (#6828)
parent
30d14a96
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
27 deletions
+42
-27
research/tensorrt/tensorrt.py
research/tensorrt/tensorrt.py
+42
-27
No files found.
research/tensorrt/tensorrt.py
View file @
d11d9845
...
...
@@ -31,7 +31,7 @@ import time
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.contrib.saved_model.python.saved_model
import
reader
import
tensorflow.
contrib.tensor
rt
as
trt
from
tensorflow.
python.compiler.tensorrt
import
trt_conve
rt
as
trt
from
official.resnet
import
imagenet_preprocessing
# pylint: disable=g-bad-import-order
...
...
@@ -212,38 +212,55 @@ def get_frozen_graph(graph_file):
def get_tftrt_name(graph_name, precision_string):
  """Build the canonical save-name for a TF-TRT graph.

  Args:
    graph_name: string, base name of the graph.
    precision_string: string, precision label (e.g. FP32, FP16, INT8);
      lower-cased in the result.

  Returns:
    string of the form "tftrt_<precision>_<graph_name>".
  """
  return "tftrt_%s_%s" % (precision_string.lower(), graph_name)
def get_trt_converter(graph_def, precision_mode, output_node, batch_size=128,
                      workspace_size=2 << 10):
  """Create a TrtGraphConverter object to use later.

  Uses the new TensorRT conversion API (trt_convert.TrtGraphConverter)
  instead of the removed trt.create_inference_graph helper.

  Args:
    graph_def: GraphDef, the frozen graph to be converted.
    precision_mode: string, the precision that TensorRT should convert into.
      Options: FP32, FP16, INT8.
    output_node: string, the name of the output node that will be returned
      during inference.
    batch_size: int, the number of examples that will be predicted at a time.
    workspace_size: int, size in megabytes that can be used during conversion.

  Returns:
    TrtGraphConverter object configured for the requested precision.
  """
  return trt.TrtGraphConverter(
      input_graph_def=graph_def,
      # Keep the output node out of TensorRT engine segments so it can be
      # fetched by name at inference time.
      nodes_blacklist=[output_node],
      max_batch_size=batch_size,
      # workspace_size is given in MB; the API expects bytes.
      max_workspace_size_bytes=workspace_size << 20,
      precision_mode=precision_mode)
def get_trt_graph(graph_name, converter, output_dir):
  """Create and save an inference graph using the TensorRT library.

  Args:
    graph_name: string, name of the graph to be used for saving.
    converter: TrtGraphConverter object wrapping the GraphDef to convert.
    output_dir: string, the path to where files should be written.

  Returns:
    GraphDef for the TensorRT inference graph.
  """
  trt_graph = converter.convert()
  # Persist the converted graph alongside the other benchmark artifacts.
  write_graph_to_file(graph_name, trt_graph, output_dir)
  return trt_graph
def get_trt_graph_from_calib(graph_name, converter, data, input_node,
                             output_node, output_dir, num_loops=100):
  """Convert a TensorRT graph used for calibration to an inference graph.

  With the new API the converter performs INT8 calibration itself: convert()
  builds the calibration graph, then calibrate() feeds `num_loops` batches
  through it and returns the final inference graph.

  Args:
    graph_name: string, name of the graph to be used for saving.
    converter: TrtGraphConverter object created with precision_mode INT8.
    data: calibration input data, consumed through get_iterator().
    input_node: string, name of the graph's input placeholder.
    output_node: string, name of the output node to fetch during calibration.
    output_dir: string, the path to where files should be written.
    num_loops: int, number of calibration batches to run.

  Returns:
    GraphDef for the calibrated TensorRT inference graph.
  """
  converter.convert()

  def input_fn():
    # Fresh iterator per calibration run; maps the input placeholder name to
    # the next batch of calibration data.
    iterator = get_iterator(data)
    return {input_node: iterator.get_next()}

  trt_graph = converter.calibrate(
      fetch_names=[output_node],
      num_runs=num_loops,
      input_map_fn=input_fn)
  write_graph_to_file(graph_name, trt_graph, output_dir)
  return trt_graph
...
...
@@ -366,9 +383,9 @@ def run_trt_graph_for_mode(
graph_name
,
graph_def
,
mode
,
data
,
log_buffer
,
flags
):
"""Convert, time, and log the graph at `mode` precision using TensorRT."""
g_name
=
get_tftrt_name
(
graph_name
,
mode
)
graph
=
get_trt_
graph
(
g_name
,
graph_def
,
mode
,
flags
.
output_
dir
,
flags
.
output_node
,
flags
.
batch_size
,
flags
.
workspace_size
)
trt_converter
=
get_trt_
converter
(
graph_def
,
mode
,
flags
.
output_
node
,
flags
.
batch_size
,
flags
.
workspace_size
)
graph
=
get_trt_graph
(
g_name
,
trt_converter
,
flags
.
output_dir
)
result
=
time_and_log_graph
(
g_name
,
graph
,
data
,
log_buffer
,
flags
)
return
result
...
...
@@ -476,15 +493,13 @@ def main(argv):
if
flags
.
int8
:
mode
=
"INT8"
print
(
"Running {} graph"
.
format
(
mode
))
save_name
=
get_tftrt_name
(
graph_name
,
"INT8_calib"
)
calib_graph
=
get_trt_graph
(
save_name
,
frozen_graph_def
,
mode
,
flags
.
output_dir
,
flags
.
output_node
,
flags
.
batch_size
,
flags
.
workspace_size
)
time_graph
(
calib_graph
,
data
,
flags
.
input_node
,
flags
.
output_node
,
num_loops
=
1
)
trt_converter
=
get_trt_converter
(
frozen_graph_def
,
mode
,
flags
.
output_node
,
flags
.
batch_size
,
flags
.
workspace_size
)
g_name
=
get_tftrt_name
(
graph_name
,
mode
)
int8_graph
=
get_trt_graph_from_calib
(
g_name
,
calib_graph
,
flags
.
output_dir
)
int8_graph
=
get_trt_graph_from_calib
(
g_name
,
trt_converter
,
data
,
flags
.
input_node
,
flags
.
output_node
,
flags
.
output_dir
,
num_loops
=
1
)
result
=
time_and_log_graph
(
g_name
,
int8_graph
,
data
,
log_buffer
,
flags
)
results
.
append
((
mode
,
result
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment