Unverified Commit b249fb8a authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add fp16 to accuracy checker (#2253)

parent 9a91cc25
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -52,6 +52,12 @@ def parse_args(): ...@@ -52,6 +52,12 @@ def parse_args():
parser.add_argument('--fill0', parser.add_argument('--fill0',
action='store_true', action='store_true',
help='fill all arguments with a value of 0') help='fill all arguments with a value of 0')
parser.add_argument('--fp16',
action='store_true',
help='quantize MIGraphX model to fp16')
parser.add_argument('--argmax',
action='store_true',
help='use argmax for accuracy')
parser.add_argument('--verbose', parser.add_argument('--verbose',
action='store_true', action='store_true',
help='show verbose information (for debugging)') help='show verbose information (for debugging)')
...@@ -105,7 +111,7 @@ def parse_args(): ...@@ -105,7 +111,7 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args, parser
# taken from ../test_runner.py # taken from ../test_runner.py
...@@ -113,6 +119,7 @@ def check_correctness(gold_outputs, ...@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs, outputs,
rtol=1e-3, rtol=1e-3,
atol=1e-3, atol=1e-3,
use_argmax=False,
verbose=False): verbose=False):
if len(gold_outputs) != len(outputs): if len(gold_outputs) != len(outputs):
print('Number of outputs {} is not equal to expected number {}'.format( print('Number of outputs {} is not equal to expected number {}'.format(
...@@ -121,6 +128,8 @@ def check_correctness(gold_outputs, ...@@ -121,6 +128,8 @@ def check_correctness(gold_outputs,
out_num = len(gold_outputs) out_num = len(gold_outputs)
ret = True ret = True
if not use_argmax:
for i in range(out_num): for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False ret = False
...@@ -132,7 +141,16 @@ def check_correctness(gold_outputs, ...@@ -132,7 +141,16 @@ def check_correctness(gold_outputs,
else: else:
print('Outputs do not match') print('Outputs do not match')
break break
else:
golden_argmax = np.argmax(gold_outputs)
actual_argmax = np.argmax(outputs)
if actual_argmax != golden_argmax:
ret = False
print('\nOutput argmax is incorrect ...')
if verbose:
print('Expected argmax value: \n{}'.format(golden_argmax))
print('......')
print('Actual argmax value: \n{}\n'.format(actual_argmax))
return ret return ret
...@@ -155,13 +173,14 @@ def get_np_datatype(in_type): ...@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
def main(): def main():
args = parse_args() args, parser = parse_args()
use_onnx = True use_onnx = True
if args.onnx == None: if args.onnx == None:
use_onnx = False use_onnx = False
if not use_onnx and args.tf == None: if not use_onnx and args.tf == None:
print('Error: please specify either an onnx or tf pb file') print('Error: please specify either an onnx or tf pb file')
parser.print_help()
sys.exit(-1) sys.exit(-1)
model_name = args.onnx model_name = args.onnx
...@@ -194,6 +213,9 @@ def main(): ...@@ -194,6 +213,9 @@ def main():
batch_size=batch, batch_size=batch,
map_input_dims=input_dims) map_input_dims=input_dims)
if (args.fp16):
migraphx.quantize_fp16(model)
if args.verbose: if args.verbose:
print(model) print(model)
...@@ -300,7 +322,8 @@ def main(): ...@@ -300,7 +322,8 @@ def main():
if not args.ort_run: if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance, is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose) args.tolerance, args.argmax,
args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \ verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else '' if not args.verbose else ''
if is_correct: if is_correct:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment