Commit f3a8933c authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into blas_tuning

parents ca300bd6 b249fb8a
...@@ -31,10 +31,13 @@ ...@@ -31,10 +31,13 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; } bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; } bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
bool is_clip_scalar(migraphx::instruction& ins) bool is_clip_scalar(migraphx::instruction& ins)
...@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear) ...@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear)
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
// ensure clip literals created in quantized program are scalar // ensure clip literals created in quantized program are scalar
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar)); // unless CK workarounds are enabled
if(migraphx::enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
EXPECT(none_of(*p2.get_main_module(), &is_clip_scalar));
else
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
} }
TEST_CASE(dequantizelinear) TEST_CASE(dequantizelinear)
......
...@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target) ...@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target)
TEST_CASE(targets) TEST_CASE(targets)
{ {
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto ref_target = migraphx::make_target("ref"); auto ref_target = migraphx::make_target("ref");
#endif
auto ts = migraphx::get_targets(); auto ts = migraphx::get_targets();
EXPECT(ts.size() >= 1); EXPECT(ts.size() >= 1);
} }
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_default_mode : verify_program<quant_conv_default_mode> struct quant_conv_1 : verify_program<quant_conv_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_int8x4_default : verify_program<quant_conv_int8x4_default> struct quant_conv_2 : verify_program<quant_conv_2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -29,8 +29,8 @@ ...@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
template <class T, int Axis, int NonStdShape> template <class T, int Axis, bool LastIndex, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis, LastIndex, NonStdShape>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>> ...@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break; break;
default: break; default: break;
} }
mm->add_instruction(T{Axis}, param); mm->add_instruction(T{Axis, LastIndex}, param);
return p; return p;
} }
}; };
// transpose argmax tests // transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 0>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, 0>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, 0>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 0>;
// transpose argmin tests // transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 0>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, 0>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, 0>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 0>;
// broadcast argmax tests // broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 1>;
// broadcast argmin tests // broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, 1>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, 1>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 1>;
// slice argmax tests // slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 2>;
// slice argmin tests // slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 2>;
// default case, standard shape argmax tests // default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 3>;
// default case, standard shape argmin tests // default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, 3>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, 3>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 3>;
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -52,6 +52,12 @@ def parse_args(): ...@@ -52,6 +52,12 @@ def parse_args():
parser.add_argument('--fill0', parser.add_argument('--fill0',
action='store_true', action='store_true',
help='fill all arguments with a value of 0') help='fill all arguments with a value of 0')
parser.add_argument('--fp16',
action='store_true',
help='quantize MIGraphX model to fp16')
parser.add_argument('--argmax',
action='store_true',
help='use argmax for accuracy')
parser.add_argument('--verbose', parser.add_argument('--verbose',
action='store_true', action='store_true',
help='show verbose information (for debugging)') help='show verbose information (for debugging)')
...@@ -105,7 +111,7 @@ def parse_args(): ...@@ -105,7 +111,7 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args, parser
# taken from ../test_runner.py # taken from ../test_runner.py
...@@ -113,6 +119,7 @@ def check_correctness(gold_outputs, ...@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs, outputs,
rtol=1e-3, rtol=1e-3,
atol=1e-3, atol=1e-3,
use_argmax=False,
verbose=False): verbose=False):
if len(gold_outputs) != len(outputs): if len(gold_outputs) != len(outputs):
print('Number of outputs {} is not equal to expected number {}'.format( print('Number of outputs {} is not equal to expected number {}'.format(
...@@ -121,18 +128,29 @@ def check_correctness(gold_outputs, ...@@ -121,18 +128,29 @@ def check_correctness(gold_outputs,
out_num = len(gold_outputs) out_num = len(gold_outputs)
ret = True ret = True
for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not use_argmax:
for i in range(out_num):
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False
if verbose:
print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}'.format(gold_outputs[i]))
print('......')
print('Actual value: \n{}\n'.format(outputs[i]))
else:
print('Outputs do not match')
break
else:
golden_argmax = np.argmax(gold_outputs)
actual_argmax = np.argmax(outputs)
if actual_argmax != golden_argmax:
ret = False ret = False
print('\nOutput argmax is incorrect ...')
if verbose: if verbose:
print('\nOutput {} is incorrect ...'.format(i)) print('Expected argmax value: \n{}'.format(golden_argmax))
print('Expected value: \n{}'.format(gold_outputs[i]))
print('......') print('......')
print('Actual value: \n{}\n'.format(outputs[i])) print('Actual argmax value: \n{}\n'.format(actual_argmax))
else:
print('Outputs do not match')
break
return ret return ret
...@@ -155,13 +173,14 @@ def get_np_datatype(in_type): ...@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
def main(): def main():
args = parse_args() args, parser = parse_args()
use_onnx = True use_onnx = True
if args.onnx == None: if args.onnx == None:
use_onnx = False use_onnx = False
if not use_onnx and args.tf == None: if not use_onnx and args.tf == None:
print('Error: please specify either an onnx or tf pb file') print('Error: please specify either an onnx or tf pb file')
parser.print_help()
sys.exit(-1) sys.exit(-1)
model_name = args.onnx model_name = args.onnx
...@@ -194,6 +213,9 @@ def main(): ...@@ -194,6 +213,9 @@ def main():
batch_size=batch, batch_size=batch,
map_input_dims=input_dims) map_input_dims=input_dims)
if (args.fp16):
migraphx.quantize_fp16(model)
if args.verbose: if args.verbose:
print(model) print(model)
...@@ -300,7 +322,8 @@ def main(): ...@@ -300,7 +322,8 @@ def main():
if not args.ort_run: if not args.ort_run:
is_correct = check_correctness(pred_fw, pred_migx, args.tolerance, is_correct = check_correctness(pred_fw, pred_migx, args.tolerance,
args.tolerance, args.verbose) args.tolerance, args.argmax,
args.verbose)
verbose_string = ' Rerun with --verbose for detailed information.' \ verbose_string = ' Rerun with --verbose for detailed information.' \
if not args.verbose else '' if not args.verbose else ''
if is_correct: if is_correct:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment