Unverified Commit 6c8b978d authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Add Additional flags to accuracy_checker.py (#1637)

Useful to get more insight into Onnxruntime. Allows us to reuse the accuracy checker code while also allowing us to capture Execution Provider output with the --ort_run and --ort_logging flags

Also added the --target flag as well to allow us to force using either a specific target for the accuracy checking. Originally this was defaulting to the GPU. This now allows us to use ref, fpga, etc to quickly change targets.
parent 09aaa63e
......@@ -63,8 +63,25 @@ def parse_args():
type=str,
action='append',
help='specify input parameter dimension \
with the following format --input_dim input_name:dim0,dim1,dim2...'
with the following format --input-dim input_name:dim0,dim1,dim2...'
)
parser.add_argument('--target',
type=str,
default='gpu',
help='target to compile and run MIGraphX on')
parser.add_argument('--ort-run',
dest="ort_run",
action='store_true',
default=False,
help='only perform an onnxruntime run')
parser.add_argument('--ort-logging',
dest="ort_logging",
action='store_true',
default=False,
help='Turn on ort VERBOSE logging via session options')
args = parser.parse_args()
return args
......@@ -111,7 +128,7 @@ def get_np_datatype(in_type):
'uint16_type': np.uint16,
'int8_type': np.int8,
'uint8_type': np.uint8,
'bool_type': np.bool_
'bool_type': bool
}
return datatypes[in_type]
......@@ -159,7 +176,8 @@ def main():
if args.verbose:
print(model)
model.compile(migraphx.get_target('gpu'))
if not args.ort_run:
model.compile(migraphx.get_target(args.target))
params = {}
test_inputs = {}
......@@ -178,10 +196,19 @@ def main():
test_inputs[name] = test_input
params[name] = migraphx.argument(test_input)
pred_migx = np.array(model.run(params)[-1])
if not args.ort_run:
pred_migx = np.array(model.run(params)[-1])
if use_onnx:
sess = ort.InferenceSession(model_name, providers=[args.provider])
sess_op = ort.SessionOptions()
if args.ort_logging:
sess_op.log_verbosity_level = 0
sess_op.log_severity_level = 0
sess = ort.InferenceSession(model_name,
sess_options=sess_op,
providers=[args.provider])
ort_params = {}
for input in sess.get_inputs():
......@@ -239,14 +266,15 @@ def main():
y_out = sess.run(y, feed_dict=tf_dict)
pred_fw = y_out
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else ''
if is_correct:
print('PASSED: MIGraphX meets tolerance')
else:
print('FAILED: MIGraphX is not within tolerance.' + verbose_string)
if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else ''
if is_correct:
print('PASSED: MIGraphX meets tolerance')
else:
print('FAILED: MIGraphX is not within tolerance.' + verbose_string)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment