Unverified Commit 87a4dcc3 authored by Karmel Allison's avatar Karmel Allison Committed by GitHub
Browse files

Update unit for workspace size. (#4094)

parent de9f3584
......@@ -215,7 +215,7 @@ def get_tftrt_name(graph_name, precision_string):
def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
output_node, batch_size=128, workspace_size=1<<30):
output_node, batch_size=128, workspace_size=2<<10):
"""Create and save inference graph using the TensorRT library.
Args:
......@@ -227,14 +227,14 @@ def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
output_node: string, the names of the output node that will
be returned during inference.
batch_size: int, the number of examples that will be predicted at a time.
workspace_size: long, size in bytes that can be used during conversion.
workspace_size: int, size in megabytes that can be used during conversion.
Returns:
GraphDef for the TensorRT inference graph.
"""
trt_graph = trt.create_inference_graph(
graph_def, [output_node], max_batch_size=batch_size,
max_workspace_size_bytes=workspace_size,
max_workspace_size_bytes=workspace_size<<20,
precision_mode=precision_mode)
write_graph_to_file(graph_name, trt_graph, output_dir)
......@@ -589,8 +589,8 @@ class TensorRTParser(argparse.ArgumentParser):
)
self.add_argument(
"--workspace_size", "-ws", type=long, default=2<<30,
help="[default: %(default)s] Workspace size in bytes.",
"--workspace_size", "-ws", type=int, default=2<<10,
help="[default: %(default)s] Workspace size in megabytes.",
metavar="<WS>"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment