Unverified Commit 74ba9649 authored by Paul Fultz II, committed by GitHub

Add fp16 flag to test runner to check models quantized to fp16 (#2182)

parent f149b619
@@ -39,6 +39,15 @@ def parse_args():
                         type=str,
                         default='gpu',
                         help='Specify where the tests execute (ref, gpu)')
+    parser.add_argument('--fp16', action='store_true', help='Quantize to fp16')
+    parser.add_argument('--atol',
+                        type=float,
+                        default=1e-3,
+                        help='The absolute tolerance parameter')
+    parser.add_argument('--rtol',
+                        type=float,
+                        default=1e-3,
+                        help='The relative tolerance parameter')
     args = parser.parse_args()
     return args
@@ -257,6 +266,8 @@ def main():
     # read and compile model
     model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
+    if args.fp16:
+        migraphx.quantize_fp16(model)
     model.compile(migraphx.get_target(target))

     # get test cases
@@ -279,7 +290,10 @@ def main():
         output_data = run_one_case(model, input_data)

         # check output correctness
-        ret = check_correctness(gold_outputs, output_data)
+        ret = check_correctness(gold_outputs,
+                                output_data,
+                                atol=args.atol,
+                                rtol=args.rtol)
         if ret:
             correct_num += 1
...
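
For context, a minimal sketch of the path this change adds: the ONNX model is parsed, optionally quantized to fp16 with migraphx.quantize_fp16 before compilation, and outputs are later compared using the requested absolute/relative tolerances. The model path, input name and shape, and the numpy.allclose comparison are illustrative assumptions; parse_onnx, quantize_fp16, get_target, and compile are the calls shown in the diff above, while run comes from the MIGraphX Python API used inside the runner's run_one_case helper.

import argparse

import migraphx
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true', help='Quantize to fp16')
parser.add_argument('--atol', type=float, default=1e-3)
parser.add_argument('--rtol', type=float, default=1e-3)
args = parser.parse_args()

# Parse the ONNX model (path, input name, and dims are placeholders).
model = migraphx.parse_onnx('model.onnx',
                            map_input_dims={'data': [1, 3, 224, 224]})

if args.fp16:
    # Quantize the program to fp16 in place before compiling, as in the diff above.
    migraphx.quantize_fp16(model)

model.compile(migraphx.get_target('gpu'))

# Run with random input; a reference result would then be compared with the
# requested tolerances, analogous to the runner's check_correctness call:
# np.allclose(np.array(outputs[0]), reference_output, atol=args.atol, rtol=args.rtol)
data = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = model.run({'data': data})

An example invocation of the runner with the new flags (script name and test directory are assumptions): python test_runner.py ./models --fp16 --atol 1e-2 --rtol 1e-2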