Unverified commit 87a4dcc3 authored by Karmel Allison, committed by GitHub

Update unit for workspace size. (#4094)

parent de9f3584
@@ -215,7 +215,7 @@ def get_tftrt_name(graph_name, precision_string):
 def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
-                  output_node, batch_size=128, workspace_size=1<<30):
+                  output_node, batch_size=128, workspace_size=2<<10):
   """Create and save inference graph using the TensorRT library.
 
   Args:
@@ -227,14 +227,14 @@ def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
     output_node: string, the names of the output node that will
       be returned during inference.
     batch_size: int, the number of examples that will be predicted at a time.
-    workspace_size: long, size in bytes that can be used during conversion.
+    workspace_size: int, size in megabytes that can be used during conversion.
 
   Returns:
     GraphDef for the TensorRT inference graph.
   """
   trt_graph = trt.create_inference_graph(
       graph_def, [output_node], max_batch_size=batch_size,
-      max_workspace_size_bytes=workspace_size,
+      max_workspace_size_bytes=workspace_size<<20,
       precision_mode=precision_mode)
   write_graph_to_file(graph_name, trt_graph, output_dir)
@@ -589,8 +589,8 @@ class TensorRTParser(argparse.ArgumentParser):
     )
     self.add_argument(
-        "--workspace_size", "-ws", type=long, default=2<<30,
-        help="[default: %(default)s] Workspace size in bytes.",
+        "--workspace_size", "-ws", type=int, default=2<<10,
+        help="[default: %(default)s] Workspace size in megabytes.",
         metavar="<WS>"
     )
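
The effect of this change is that --workspace_size is now expressed in megabytes and converted back to bytes inside get_trt_graph with a 20-bit left shift (1 MB == 1 << 20 bytes). A minimal sketch of that arithmetic, using illustrative variable names that are not part of the patch:

# Sketch of the unit conversion this commit introduces (illustrative names,
# not from the patch). The flag value is interpreted as megabytes; the byte
# count passed to max_workspace_size_bytes is recovered with a 20-bit shift.
workspace_size_mb = 2 << 10                     # new default: 2048 MB
workspace_size_bytes = workspace_size_mb << 20  # value handed to TensorRT

# 2048 MB shifted by 20 bits is 2 << 30 bytes (2 GiB), matching the parser's
# old byte-denominated default, so the effective flag budget is unchanged.
assert workspace_size_bytes == 2 << 30

Note that the function-signature default moves from 1<<30 bytes (1 GiB) to 2<<10 MB (an effective 2 GiB), while the argument-parser default keeps the same effective 2 GiB budget it had before.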