Commit f445d962, authored by Paul Fultz II and committed by mvermeulen
Browse files

Add flags to driver to run quantization (#361)

* Add flags to quantize in driver

* Formatting

* Fix compile error
parent ef5e7ce0
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
...@@ -79,32 +80,47 @@ struct loader ...@@ -79,32 +80,47 @@ struct loader
struct compiler struct compiler
{ {
static const int q_fp16 = 1;
static const int q_int8 = 2;
loader l; loader l;
bool gpu = true; bool gpu = true;
int quantize = 0;
std::vector<std::string> fill1; std::vector<std::string> fill1;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); l.parse(ap);
ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true)); ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false)); ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8));
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append()); ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append());
} }
program compile() auto params(const program& p, bool use_gpu = true)
{
auto p = l.load();
compile_program(p, gpu);
return p;
}
auto params(const program& p)
{ {
program::parameter_map m; program::parameter_map m;
for(auto&& s : fill1) for(auto&& s : fill1)
m[s] = fill_argument(p.get_parameter_shape(s), 1); m[s] = fill_argument(p.get_parameter_shape(s), 1);
fill_param_map(m, p, gpu); fill_param_map(m, p, use_gpu && gpu);
return m; return m;
} }
// Load the program, apply the quantization mode selected by the
// --fp16 / --int8 flags (if any), then compile it for the chosen target.
// Returns the compiled program.
program compile()
{
    auto prog = l.load();
    auto tgt  = get_target(gpu);

    // `quantize` stays 0 when neither flag was given, in which case the
    // program is compiled without any quantization step.
    switch(quantize)
    {
    case q_fp16:
        quantize_fp16(prog);
        break;
    case q_int8:
        // int8 quantization is calibrated with a single parameter map,
        // built with use_gpu = false.
        quantize_int8(prog, tgt, {params(prog, false)});
        break;
    default:
        break;
    }

    prog.compile(tgt);
    return prog;
}
}; };
struct read : command<read> struct read : command<read>
......
...@@ -45,6 +45,22 @@ program::parameter_map create_param_map(const program& p, bool gpu) ...@@ -45,6 +45,22 @@ program::parameter_map create_param_map(const program& p, bool gpu)
return m; return m;
} }
// Pick the target to compile for.
//
// gpu == false: the cpu target is returned.
// gpu == true:  the gpu target is returned when the build defines HAVE_GPU;
//               otherwise an error is reported through MIGRAPHX_THROW.
target get_target(bool gpu)
{
    if(!gpu)
        return cpu::target{};

#ifdef HAVE_GPU
    return gpu::target{};
#else
    MIGRAPHX_THROW("Gpu not supported.");
#endif
}
void compile_program(program& p, bool gpu) void compile_program(program& p, bool gpu)
{ {
if(gpu) if(gpu)
......
...@@ -9,6 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -9,6 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu); program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu);
program::parameter_map create_param_map(const program& p, bool gpu = true); program::parameter_map create_param_map(const program& p, bool gpu = true);
// Returns the gpu target when `gpu` is true (reporting an error via
// MIGRAPHX_THROW if the build was made without HAVE_GPU), otherwise the
// cpu target.
target get_target(bool gpu);
void compile_program(program& p, bool gpu = true); void compile_program(program& p, bool gpu = true);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -38,7 +38,7 @@ capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_name ...@@ -38,7 +38,7 @@ capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_name
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration, const std::vector<program::parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"}); const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog, void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params, const std::vector<std::pair<float, float>>& quant_params,
......
...@@ -414,7 +414,7 @@ void quantize_int8_impl(program& prog, ...@@ -414,7 +414,7 @@ void quantize_int8_impl(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration, const std::vector<program::parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::vector<std::string>& ins_names)
{ {
// insert capture operator // insert capture operator
...@@ -433,8 +433,8 @@ void quantize_int8(program& prog, ...@@ -433,8 +433,8 @@ void quantize_int8(program& prog,
{ {
if(arg.count(x.first) > 0) if(arg.count(x.first) > 0)
{ {
assert(x.second == arg[x.first].get_shape()); assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg[x.first]); m[x.first] = t.copy_to(arg.at(x.first));
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment