Commit 41ed1924 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine int8 quantization APIs

parent 318dbc15
......@@ -22,10 +22,17 @@ void quantize(program& prog);
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names);
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names = {"dot"});
// Template front end for capture_arguments_impl: it forwards prog, t and
// ins_names unchanged and returns whatever capture_arguments_impl returns.
// It is a template only so that the static_assert below can inspect the value
// category of the target argument at compile time.
// ins_names defaults to {"dot"} when the caller does not supply it.
template<class T>
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog,
const target& t);
T&& t, const std::vector<std::string>& ins_names = {"dot"})
{
// T must deduce to an lvalue reference whose referred type, ignoring
// cv-qualifiers, is exactly `target`. Passing a temporary, or an lvalue of any
// other type, fails to compile with the "Dangling reference" message.
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} && std::is_lvalue_reference<T>{}, "Dangling reference to target!");
return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog,
const target& t,
......
......@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
......@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar
}
template <class T>
argument copy_to_target(rank<0>, T& x, const argument&)
argument copy_to_target(rank<0>, T&, const argument& arg)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
return arg;
}
template <class T>
......@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro
}
template <class T>
argument copy_from_target(rank<0>, T& x, const argument&)
argument copy_from_target(rank<0>, T&, const argument& arg)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
return arg;
}
template <class T>
......
......@@ -502,7 +502,7 @@ std::size_t capture_arguments(program& prog,
}
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names)
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
{
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
......@@ -515,7 +515,7 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
// scale and shift are needed only for the int8 type, and we do not
// consider shift, so shift is set to 0
std::vector<float> vec_val;
auto&& arg = t.copy_from(args.front());
argument arg = t.copy_from(args.front());
arg.visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
......@@ -534,12 +534,5 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
return int8_quant_params;
}
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog,
const target& t)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
return capture_arguments(prog, t, ins_names);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -17,7 +17,9 @@ struct target
migraphx::context get_context() const { return context{}; }
// Returns a copy of arg. The parameter is a const reference, so std::move on it
// could never move (it only yields a const rvalue, which still copies); return
// it directly, the same way copy_from does.
argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return std::move(arg); }
argument copy_from(const argument& arg) const {
return arg;
}
argument allocate(const shape& s) const;
};
......
......@@ -2067,7 +2067,8 @@ TEST_CASE(op_capture)
p.add_instruction(migraphx::op::dot{}, pa, ps);
migraphx::program capture_p = p;
migraphx::capture_arguments(capture_p);
migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(capture_p, t);
p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{});
......
......@@ -248,7 +248,8 @@ TEST_CASE(op_capture)
{
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::capture_arguments(p);
migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(p, t);
EXPECT(p == op_capture_p);
}
}
......
......@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
......@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar
}
template <class T>
argument copy_to_target(rank<0>, T& x, const argument&)
argument copy_to_target(rank<0>, T&, const argument& arg)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
return arg;
}
template <class T>
......@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro
}
template <class T>
argument copy_from_target(rank<0>, T& x, const argument&)
argument copy_from_target(rank<0>, T&, const argument& arg)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
return arg;
}
template <class T>
......@@ -124,9 +118,9 @@ argument copy_from_target(T& x, const argument& arg)
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment