Commit 4e433399 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

implement the capture framework.

parent cc9efa4e
...@@ -18,7 +18,8 @@ namespace op { ...@@ -18,7 +18,8 @@ namespace op {
struct capture struct capture
{ {
std::function<void(std::vector<argument>)> f; std::size_t ins_index;
std::function<void(std::size_t ins_index, std::vector<argument>)> f;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -31,11 +32,8 @@ struct capture ...@@ -31,11 +32,8 @@ struct capture
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; f(ins_index, args);
args.push_back(result); return args.front();
f(args);
return result;
} }
}; };
......
...@@ -15,6 +15,10 @@ struct program; ...@@ -15,6 +15,10 @@ struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names); void quantize(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog); void quantize(program& prog);
// insert the capture operator for the inputs of each operator to be quantized
// to int8
void capture_arguments(program& prog, const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp> #include <migraphx/op/convert.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <utility> #include <utility>
...@@ -103,5 +104,63 @@ void quantize(program& prog, const std::vector<std::string>& ins_names) ...@@ -103,5 +104,63 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
void quantize(program& prog) { quantize(prog, {"all"}); } void quantize(program& prog) { quantize(prog, {"all"}); }
std::vector<std::vector<argument> > ins_args;
void capture_args(std::size_t ins_index, std::vector<argument> args) {
if (ins_index = ins_args.size())
{
ins_args.push_back(std::vector<argument>{});
}
ins_args[ins_index].push_back(args.front());
return;
}
// Stub: currently does nothing. Neither `ins_arg` nor `ins_params` is read or
// modified.
// NOTE(review): judging by the name this is presumably where quantization
// parameters will be derived from the captured arguments -- confirm once it
// is implemented.
void calc_quant_params(std::vector<std::vector<argument>>& ins_arg,
                       std::vector<std::pair<float, float>>& ins_params)
{
    // Intentionally empty; the casts only mark the parameters as unused.
    (void)ins_arg;
    (void)ins_params;
}
// Wraps every input of the selected instructions in an op::capture
// instruction so that the values flowing into those instructions are reported
// to capture_args (the stated goal being to compute scale and shift later).
//
// prog:      program to modify in place.
// ins_names: names of the operators whose inputs should be captured; each
//            name must be "dot" or "convolution", otherwise MIGRAPHX_THROW
//            is raised.
//
// An input shared by several selected instructions gets a single capture
// instruction, which is then reused for all of them.
void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{
    // int8 quantization is only supported for these operators
    const std::vector<std::string> supported_ops = {"dot", "convolution"};
    const bool has_unsupported =
        std::any_of(ins_names.begin(), ins_names.end(), [&](const std::string& name) {
            return std::find(supported_ops.begin(), supported_ops.end(), name) ==
                   supported_ops.end();
        });
    if(has_unsupported)
    {
        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
    }

    // original input -> capture instruction already inserted for it
    std::unordered_map<instruction_ref, instruction_ref> capture_of;
    std::size_t capture_index = 0;
    for(auto ins : iterator_for(prog))
    {
        if(!contains(ins_names, ins->name()))
        {
            continue;
        }

        // Work on a copy of the input list, as the original code does.
        const auto inputs = ins->inputs();
        std::vector<instruction_ref> captured_inputs;
        for(auto input : inputs)
        {
            auto found = capture_of.find(input);
            if(found == capture_of.end())
            {
                // First time this input is seen: place a capture instruction
                // right after it and give it the next free index.
                auto capture_ins = prog.insert_instruction(
                    std::next(input), op::capture{capture_index++, capture_args}, input);
                found = capture_of.emplace(input, capture_ins).first;
            }
            captured_inputs.push_back(found->second);
        }

        // Same operator and shape, but fed by the capture instructions.
        instruction::replace(ins, ins->get_operator(), ins->get_shape(), captured_inputs);
    }
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment