Commit 31aed3ec authored by Umang Yadav's avatar Umang Yadav
Browse files

use threshold instead of tolerance

parent 482e8d61
...@@ -50,9 +50,9 @@ Runs reference and CPU or GPU implementations and checks outputs for consistency ...@@ -50,9 +50,9 @@ Runs reference and CPU or GPU implementations and checks outputs for consistency
.. include:: ./driver/compile.rst .. include:: ./driver/compile.rst
.. option:: --tolerance [double] .. option:: --threshold [double]
Tolerance for errors (Default: 80) Threshold for RMS error (Default: 0.001)
.. option:: -i, --per-instruction .. option:: -i, --per-instruction
......
...@@ -55,7 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as ...@@ -55,7 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel | | --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 | | --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 | | --int8 | Quantize for int8 |
| --tolerance | Tolerance for errors | | --threshold | threshold for errors |
| --per-instruction \| -i | Verify each instruction | | --per-instruction \| -i | Verify each instruction |
| --reduce \| -r | Reduce program and verify | | --reduce \| -r | Reduce program and verify |
| --iterations \| -n | Number of iterations to run for perf report | | --iterations \| -n | Number of iterations to run for perf report |
......
...@@ -536,13 +536,13 @@ struct params : command<params> ...@@ -536,13 +536,13 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
compiler c; compiler c;
double tolerance = 80; double threshold = 0.001;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
c.parse(ap); c.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(threshold, {"--threshold"}, ap.help("threshold for the RMS error"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
...@@ -567,15 +567,15 @@ struct verify : command<verify> ...@@ -567,15 +567,15 @@ struct verify : command<verify>
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, c.co, quantize, tolerance); verify_instructions(p, t, c.co, quantize, threshold);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, c.co, quantize, m, tolerance); verify_reduced_program(p, t, c.co, quantize, m, threshold);
} }
else else
{ {
verify_program(c.l.file, p, t, c.co, quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, quantize, m, threshold);
} }
} }
}; };
......
...@@ -76,7 +76,7 @@ void verify_program(const std::string& name, ...@@ -76,7 +76,7 @@ void verify_program(const std::string& name,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double threshold)
{ {
auto x = run_ref(p, inputs); auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs); auto y = run_target(p, t, options, quantize, inputs);
...@@ -84,7 +84,7 @@ void verify_program(const std::string& name, ...@@ -84,7 +84,7 @@ void verify_program(const std::string& name,
std::size_t output_num = x.size(); std::size_t output_num = x.size();
for(std::size_t i = 0; i < output_num; ++i) for(std::size_t i = 0; i < output_num; ++i)
{ {
verify_args(name, x[i], y[i], tolerance); verify_args(name, x[i], y[i], threshold);
} }
} }
...@@ -92,7 +92,7 @@ void verify_instructions(const program& prog, ...@@ -92,7 +92,7 @@ void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize, precision quantize,
double tolerance) double threshold)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog)) for(auto&& ins : (*mm_prog))
...@@ -124,7 +124,7 @@ void verify_instructions(const program& prog, ...@@ -124,7 +124,7 @@ void verify_instructions(const program& prog,
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program( verify_program(
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance); ins.name(), p, t, options, quantize, create_param_map(p, false), threshold);
} }
catch(...) catch(...)
{ {
...@@ -140,14 +140,14 @@ void verify_reduced(program p, ...@@ -140,14 +140,14 @@ void verify_reduced(program p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double threshold)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << n << std::endl; std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); verify_program(std::to_string(n), p, t, options, quantize, inputs, threshold);
} }
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
...@@ -155,14 +155,14 @@ void verify_reduced_program(const program& p, ...@@ -155,14 +155,14 @@ void verify_reduced_program(const program& p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double threshold)
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl; std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(p, i, t, options, quantize, inputs, tolerance); verify_reduced(p, i, t, options, quantize, inputs, threshold);
} }
} }
......
...@@ -37,18 +37,18 @@ void verify_program(const std::string& name, ...@@ -37,18 +37,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); double threshold = 0.001);
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
double tolerance = 80); double threshold = 0.001);
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); double threshold = 0.001);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -187,11 +187,32 @@ double rms_range(const R1& r1, const R2& r2) ...@@ -187,11 +187,32 @@ double rms_range(const R1& r1, const R2& r2)
return std::numeric_limits<range_value<R1>>::max(); return std::numeric_limits<range_value<R1>>::max();
} }
template <class R>
double get_threshold(const R&, std::size_t tolerance = 80)
{
double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
return threshold;
}
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr) bool verify_range(const R1& r1,
const R2& r2,
std::size_t tolerance = 80,
double* out_error = nullptr)
{ {
// double threshold = get_threshold(r1, tolerance);
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
auto error = rms_range(r1, r2);
if(out_error != nullptr)
*out_error = error;
return error <= threshold;
}
template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double threshold, double* out_error = nullptr)
{
auto error = rms_range(r1, r2);
if(out_error != nullptr) if(out_error != nullptr)
*out_error = error; *out_error = error;
return error <= threshold; return error <= threshold;
......
...@@ -31,11 +31,15 @@ ...@@ -31,11 +31,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT MIGRAPHX_EXPORT bool verify_args(const std::string& name,
bool verify_args(const std::string& name, const argument& ref_arg,
const argument& ref_arg, const argument& target_arg,
const argument& target_arg, double threshold);
double tolerance = 80);
MIGRAPHX_EXPORT bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -30,12 +30,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -30,12 +30,12 @@ inline namespace MIGRAPHX_INLINE_NS {
bool verify_args(const std::string& name, bool verify_args(const std::string& name,
const argument& ref_arg, const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
double tolerance) double threshold)
{ {
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify::verify_range(ref, target, tolerance, &error); passed = verify::verify_range(ref, target, threshold, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
...@@ -93,5 +93,15 @@ bool verify_args(const std::string& name, ...@@ -93,5 +93,15 @@ bool verify_args(const std::string& name,
return passed; return passed;
} }
bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg,
std::size_t tolerance)
{
double threshold = 0.001;
target_arg.visit([&](auto ta) { threshold = verify::get_threshold(ta, tolerance); });
return verify_args(name, ref_arg, target_arg, threshold);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -118,7 +118,7 @@ TEST_CASE(int8_quantization) ...@@ -118,7 +118,7 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 0.0119209));
else else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
} }
......
...@@ -83,7 +83,7 @@ TEST_CASE(param_add) ...@@ -83,7 +83,7 @@ TEST_CASE(param_add)
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto fs = mm->add_instruction( auto fs = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs); hs);
if(add_return) if(add_return)
{ {
...@@ -1077,7 +1077,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1077,7 +1077,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 0.003576));
} }
} }
......
...@@ -78,5 +78,5 @@ TEST_CASE(multinomial_test) ...@@ -78,5 +78,5 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000)); EXPECT(migraphx::verify::verify_range(norm, res_norm, 0.01));
} }
...@@ -68,7 +68,7 @@ TEST_CASE(random_uniform_test) ...@@ -68,7 +68,7 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100)); EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 0.00001));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
......
...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16) ...@@ -1008,7 +1008,7 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold, 5e4)); EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold, 0.005));
} }
TEST_CASE(gru_forward) TEST_CASE(gru_forward)
...@@ -2983,7 +2983,7 @@ TEST_CASE(gru_fp16) ...@@ -2983,7 +2983,7 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4)); EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 0.00596));
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment