"docs/zh_CN/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "edd3f8aca8eacc79cd7847c4f5fe905fc88027b1"
Unverified Commit 74ba9649 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add fp16 flag to test runner to check models quantized to fp16 (#2182)

parent f149b619
...@@ -39,6 +39,15 @@ def parse_args(): ...@@ -39,6 +39,15 @@ def parse_args():
type=str, type=str,
default='gpu', default='gpu',
help='Specify where the tests execute (ref, gpu)') help='Specify where the tests execute (ref, gpu)')
parser.add_argument('--fp16', action='store_true', help='Quantize to fp16')
parser.add_argument('--atol',
type=float,
default=1e-3,
help='The absolute tolerance parameter')
parser.add_argument('--rtol',
type=float,
default=1e-3,
help='The relative tolerance parameter')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -257,6 +266,8 @@ def main(): ...@@ -257,6 +266,8 @@ def main():
# read and compile model # read and compile model
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes) model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
if args.fp16:
migraphx.quantize_fp16(model)
model.compile(migraphx.get_target(target)) model.compile(migraphx.get_target(target))
# get test cases # get test cases
...@@ -279,7 +290,10 @@ def main(): ...@@ -279,7 +290,10 @@ def main():
output_data = run_one_case(model, input_data) output_data = run_one_case(model, input_data)
# check output correctness # check output correctness
ret = check_correctness(gold_outputs, output_data) ret = check_correctness(gold_outputs,
output_data,
atol=args.atol,
rtol=args.rtol)
if ret: if ret:
correct_num += 1 correct_num += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment