"...composable_kernel_rocm.git" did not exist on "02153e24319b3e73c40c645376ec59d5226f0abd"
Unverified commit 6c8b978d authored by Ted Themistokleous, committed by GitHub

Add additional flags to accuracy_checker.py (#1637)

Useful for getting more insight into onnxruntime. This lets us reuse the accuracy checker code while also capturing Execution Provider output via the --ort-run and --ort-logging flags.

Also added a --target flag to force a specific target for the accuracy checking. Previously this defaulted to the GPU; now we can switch to ref, fpga, etc. to quickly change targets.
parent 09aaa63e
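The new --ort-logging flag simply turns on onnxruntime's verbose logging through session options, as the diff below shows. A minimal standalone sketch of that behaviour, assuming onnxruntime is installed; the model path, provider, and input dtype here are placeholders for illustration, not values from this change:

import numpy as np
import onnxruntime as ort

sess_op = ort.SessionOptions()
sess_op.log_severity_level = 0   # 0 = VERBOSE, so execution provider details are logged
sess_op.log_verbosity_level = 0  # VLOG level used once severity is verbose

# 'model.onnx' and the provider list are placeholders
sess = ort.InferenceSession('model.onnx',
                            sess_options=sess_op,
                            providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
dims = [d if isinstance(d, int) else 1 for d in inp.shape]  # assume free dims = 1
out = sess.run(None, {inp.name: np.zeros(dims, dtype=np.float32)})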
@@ -63,8 +63,25 @@ def parse_args():
                         type=str,
                         action='append',
                         help='specify input parameter dimension \
-                              with the following format --input_dim input_name:dim0,dim1,dim2...'
+                              with the following format --input-dim input_name:dim0,dim1,dim2...'
                         )
+    parser.add_argument('--target',
+                        type=str,
+                        default='gpu',
+                        help='target to compile and run MIGraphX on')
+    parser.add_argument('--ort-run',
+                        dest="ort_run",
+                        action='store_true',
+                        default=False,
+                        help='only perform an onnxruntime run')
+    parser.add_argument('--ort-logging',
+                        dest="ort_logging",
+                        action='store_true',
+                        default=False,
+                        help='Turn on ort VERBOSE logging via session options')
     args = parser.parse_args()
     return args
@@ -111,7 +128,7 @@ def get_np_datatype(in_type):
         'uint16_type': np.uint16,
         'int8_type': np.int8,
         'uint8_type': np.uint8,
-        'bool_type': np.bool_
+        'bool_type': bool
     }
     return datatypes[in_type]
@@ -159,7 +176,8 @@ def main():
     if args.verbose:
         print(model)
-    model.compile(migraphx.get_target('gpu'))
+    if not args.ort_run:
+        model.compile(migraphx.get_target(args.target))

     params = {}
     test_inputs = {}
@@ -178,10 +196,19 @@ def main():
         test_inputs[name] = test_input
         params[name] = migraphx.argument(test_input)

-    pred_migx = np.array(model.run(params)[-1])
+    if not args.ort_run:
+        pred_migx = np.array(model.run(params)[-1])

     if use_onnx:
-        sess = ort.InferenceSession(model_name, providers=[args.provider])
+        sess_op = ort.SessionOptions()
+        if args.ort_logging:
+            sess_op.log_verbosity_level = 0
+            sess_op.log_severity_level = 0
+        sess = ort.InferenceSession(model_name,
+                                    sess_options=sess_op,
+                                    providers=[args.provider])

         ort_params = {}
         for input in sess.get_inputs():
@@ -239,14 +266,15 @@ def main():
             y_out = sess.run(y, feed_dict=tf_dict)
             pred_fw = y_out

-    is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
-                                   args.tolerance, args.verbose)
-    verbose_string = ' Rerun with --verbose for detailed information.' \
-                      if not args.verbose else ''
-    if is_correct:
-        print('PASSED: MIGraphX meets tolerance')
-    else:
-        print('FAILED: MIGraphX is not within tolerance.' + verbose_string)
+    if not args.ort_run:
+        is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
+                                       args.tolerance, args.verbose)
+        verbose_string = ' Rerun with --verbose for detailed information.' \
+                          if not args.verbose else ''
+        if is_correct:
+            print('PASSED: MIGraphX meets tolerance')
+        else:
+            print('FAILED: MIGraphX is not within tolerance.' + verbose_string)

 if __name__ == '__main__':
...
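The --target value is passed straight to migraphx.get_target, so switching between the GPU and the reference target is just a different string. A minimal sketch under that assumption, using a placeholder model path:

import migraphx

# 'model.onnx' is a placeholder; 'ref' compiles for the CPU reference target,
# 'gpu' was the previously hard-coded default
model = migraphx.parse_onnx('model.onnx')
model.compile(migraphx.get_target('ref'))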