Commit ec71b64c authored by Umang Yadav's avatar Umang Yadav
Browse files

add tols for the driver verify

parent 076556e9
...@@ -536,13 +536,15 @@ struct params : command<params> ...@@ -536,13 +536,15 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
compiler c; compiler c;
double rms_tol = 0.001; migraphx::verify::threshold tols;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
c.parse(ap); c.parse(ap);
ap(rms_tol, {"--rms_tol"}, ap.help("Tolerance for the RMS error")); ap(tols.rms_tol, {"--rms_tol"}, ap.help("Tolerance for the RMS error"));
ap(tols.atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute error"));
ap(tols.rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative error"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
...@@ -567,15 +569,15 @@ struct verify : command<verify> ...@@ -567,15 +569,15 @@ struct verify : command<verify>
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, c.co, quantize, rms_tol); verify_instructions(p, t, c.co, quantize, tols);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, c.co, quantize, m, rms_tol); verify_reduced_program(p, t, c.co, quantize, m, tols);
} }
else else
{ {
verify_program(c.l.file, p, t, c.co, quantize, m, rms_tol); verify_program(c.l.file, p, t, c.co, quantize, m, tols);
} }
} }
}; };
......
...@@ -77,7 +77,7 @@ void verify_program(const std::string& name, ...@@ -77,7 +77,7 @@ void verify_program(const std::string& name,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double threshold) verify::threshold tols)
{ {
auto x = run_ref(p, inputs); auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs); auto y = run_target(p, t, options, quantize, inputs);
...@@ -85,7 +85,7 @@ void verify_program(const std::string& name, ...@@ -85,7 +85,7 @@ void verify_program(const std::string& name,
std::size_t output_num = x.size(); std::size_t output_num = x.size();
for(std::size_t i = 0; i < output_num; ++i) for(std::size_t i = 0; i < output_num; ++i)
{ {
verify_args_with_threshold(name, x[i], y[i], verify::threshold{threshold}); verify_args_with_threshold(name, x[i], y[i], tols);
} }
} }
...@@ -93,7 +93,7 @@ void verify_instructions(const program& prog, ...@@ -93,7 +93,7 @@ void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize, precision quantize,
double threshold) verify::threshold tols)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog)) for(auto&& ins : (*mm_prog))
...@@ -124,8 +124,7 @@ void verify_instructions(const program& prog, ...@@ -124,8 +124,7 @@ void verify_instructions(const program& prog,
{ {
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program( verify_program(ins.name(), p, t, options, quantize, create_param_map(p, false), tols);
ins.name(), p, t, options, quantize, create_param_map(p, false), threshold);
} }
catch(...) catch(...)
{ {
...@@ -141,14 +140,14 @@ void verify_reduced(program p, ...@@ -141,14 +140,14 @@ void verify_reduced(program p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double threshold) verify::threshold tols)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << n << std::endl; std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, threshold); verify_program(std::to_string(n), p, t, options, quantize, inputs, tols);
} }
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
...@@ -156,14 +155,14 @@ void verify_reduced_program(const program& p, ...@@ -156,14 +155,14 @@ void verify_reduced_program(const program& p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double threshold) verify::threshold tols)
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl; std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(p, i, t, options, quantize, inputs, threshold); verify_reduced(p, i, t, options, quantize, inputs, tols);
} }
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "precision.hpp" #include "precision.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -37,18 +38,18 @@ void verify_program(const std::string& name, ...@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double threshold = 0.001); verify::threshold tols = verify::threshold{});
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
double threshold = 0.001); verify::threshold tols = verify::threshold{});
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double threshold = 0.001); verify::threshold tols = verify::threshold{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment