"dev-scripts" did not exist on "e129f2fd81260e91453b4ee9ebd5dcdbf1b55dca"
Commit 41ed1924 authored by Shucai Xiao
Browse files

refine int8 quantization APIs

parent 318dbc15
...@@ -22,10 +22,17 @@ void quantize(program& prog); ...@@ -22,10 +22,17 @@ void quantize(program& prog);
std::size_t capture_arguments(program& prog, std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func); const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>> std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names); capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names = {"dot"});
template<class T>
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog, std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog,
const target& t); T&& t, const std::vector<std::string>& ins_names = {"dot"})
{
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} && std::is_lvalue_reference<T>{}, "Dangling reference to target!");
return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
......
...@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&) ...@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
return argument{};
} }
template <class T> template <class T>
...@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar ...@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar
} }
template <class T> template <class T>
argument copy_to_target(rank<0>, T& x, const argument&) argument copy_to_target(rank<0>, T&, const argument& arg)
{ {
std::string name = x.name(); return arg;
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
} }
template <class T> template <class T>
...@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro ...@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro
} }
template <class T> template <class T>
argument copy_from_target(rank<0>, T& x, const argument&) argument copy_from_target(rank<0>, T&, const argument& arg)
{ {
std::string name = x.name(); return arg;
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
} }
template <class T> template <class T>
......
...@@ -502,7 +502,7 @@ std::size_t capture_arguments(program& prog, ...@@ -502,7 +502,7 @@ std::size_t capture_arguments(program& prog,
} }
std::shared_ptr<std::vector<std::pair<float, float>>> std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names) capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
{ {
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params = std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>(); std::make_shared<std::vector<std::pair<float, float>>>();
...@@ -515,7 +515,7 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string> ...@@ -515,7 +515,7 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0 // consider shift, so set shift to 0
std::vector<float> vec_val; std::vector<float> vec_val;
auto&& arg = t.copy_from(args.front()); argument arg = t.copy_from(args.front());
arg.visit([&](auto output) { vec_val.assign(output.begin(), output.end()); }); arg.visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end()); auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end()); auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
...@@ -534,12 +534,5 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string> ...@@ -534,12 +534,5 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
return int8_quant_params; return int8_quant_params;
} }
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog,
const target& t)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
return capture_arguments(prog, t, ins_names);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -17,7 +17,9 @@ struct target ...@@ -17,7 +17,9 @@ struct target
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return std::move(arg); } argument copy_to(const argument& arg) const { return std::move(arg); }
argument copy_from(const argument& arg) const { return std::move(arg); } argument copy_from(const argument& arg) const {
return arg;
}
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
......
...@@ -2067,7 +2067,8 @@ TEST_CASE(op_capture) ...@@ -2067,7 +2067,8 @@ TEST_CASE(op_capture)
p.add_instruction(migraphx::op::dot{}, pa, ps); p.add_instruction(migraphx::op::dot{}, pa, ps);
migraphx::program capture_p = p; migraphx::program capture_p = p;
migraphx::capture_arguments(capture_p); migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(capture_p, t);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{}); capture_p.compile(migraphx::cpu::target{});
......
...@@ -248,7 +248,8 @@ TEST_CASE(op_capture) ...@@ -248,7 +248,8 @@ TEST_CASE(op_capture)
{ {
auto p = create_program_float(); auto p = create_program_float();
auto op_capture_p = create_program_op(); auto op_capture_p = create_program_op();
migraphx::capture_arguments(p); migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(p, t);
EXPECT(p == op_capture_p); EXPECT(p == op_capture_p);
} }
} }
......
...@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&) ...@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
return argument{};
} }
template <class T> template <class T>
...@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar ...@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar
} }
template <class T> template <class T>
argument copy_to_target(rank<0>, T& x, const argument&) argument copy_to_target(rank<0>, T&, const argument& arg)
{ {
std::string name = x.name(); return arg;
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
} }
template <class T> template <class T>
...@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro ...@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro
} }
template <class T> template <class T>
argument copy_from_target(rank<0>, T& x, const argument&) argument copy_from_target(rank<0>, T&, const argument& arg)
{ {
std::string name = x.name(); return arg;
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
} }
template <class T> template <class T>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment