Commit 7e297b13 authored by Paul

Merge

parents 86ea5e91 aa7ff911
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_thresholdedrelu : op_parser<parse_thresholdedrelu>
{
    std::vector<op_desc> operators() const { return {{"ThresholdedRelu"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        float alpha = 1.0;
        if(contains(info.attributes, "alpha"))
            alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();

        auto x_shape = args[0]->get_shape();

        auto lit_zero = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
        auto lit_alpha =
            info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {alpha}});

        auto mb_zero = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_zero);
        auto mb_alpha = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_alpha);

        auto condition = info.add_instruction(migraphx::make_op("greater"), args[0], mb_alpha);

        return info.add_instruction(migraphx::make_op("where"), condition, args[0], mb_zero);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
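For reference, this lowering implements the ONNX ThresholdedRelu rule, y = x where x > alpha and 0 elsewhere, built here from greater plus where on a broadcast alpha. A minimal standalone sketch of the same element-wise rule (plain C++, not MIGraphX code):

#include <cstddef>
#include <iostream>
#include <vector>

// Element-wise ThresholdedRelu: keep values strictly greater than alpha,
// zero out everything else (mirrors the greater/where lowering above).
std::vector<float> thresholded_relu(const std::vector<float>& x, float alpha = 1.0f)
{
    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] > alpha ? x[i] : 0.0f;
    return y;
}

int main()
{
    for(float v : thresholded_relu({-1.0f, 0.5f, 1.0f, 2.5f}, 1.0f))
        std::cout << v << ' '; // prints: 0 0 0 2.5 (1.0 is not strictly greater)
}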
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_topk : op_parser<parse_topk>
{
    std::vector<op_desc> operators() const { return {{"TopK"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        int64_t k = 0;
        if(args.size() == 2)
        {
            auto arg_k = args.at(1)->eval();
            check_arg_empty(arg_k, "PARSE_TopK: k input must be constant");
            k = arg_k.at<int>();
        }
        else if(contains(info.attributes, "k"))
        {
            k = info.attributes.at("k").i();
        }

        bool largest = true;
        if(contains(info.attributes, "largest"))
        {
            largest = static_cast<bool>(info.attributes.at("largest").i());
        }

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int>();
        }

        auto topk_ret = info.add_instruction(
            make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0));
        auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret);
        auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret);

        return {ret_val, ret_ind};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
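As a cross-check of the semantics, here is a minimal 1-D reference of what the two tuple elements hold (a values tensor and an indices tensor, largest or smallest first); a plain C++ sketch assuming ties break toward the lower index, not MIGraphX code:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Return the top-k (largest or smallest) values of v and their original indices.
std::pair<std::vector<float>, std::vector<int64_t>>
topk(const std::vector<float>& v, int64_t k, bool largest = true)
{
    std::vector<int64_t> idx(v.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), [&](int64_t a, int64_t b) {
        return largest ? v[a] > v[b] : v[a] < v[b];
    });
    idx.resize(k);
    std::vector<float> vals(k);
    std::transform(idx.begin(), idx.end(), vals.begin(), [&](int64_t i) { return v[i]; });
    return {vals, idx};
}

int main()
{
    auto [vals, ind] = topk({3.0f, 1.0f, 4.0f, 1.5f}, 2);
    std::cout << vals[0] << ' ' << vals[1] << '\n'; // 4 3
    std::cout << ind[0] << ' ' << ind[1] << '\n';   // 2 0
}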
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/instruction.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -21,7 +22,21 @@ struct parse_transpose : op_parser<parse_transpose>
         auto&& perm_vals = info.attributes["perm"].ints();
         perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
     }
-    return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
+
+    // if perm is empty, use the default value
+    auto n_dim = args.front()->get_shape().lens().size();
+    if(perm.empty())
+    {
+        perm.resize(n_dim);
+        std::iota(perm.rbegin(), perm.rend(), 0);
+    }
+
+    if(perm.size() != n_dim)
+    {
+        MIGRAPHX_THROW("PARSE_TRANSPOSE: perm and input have different number of dims!");
+    }
+
+    return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
 }
};
...
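The new default branch fills perm through reverse iterators, so a missing perm attribute reverses all axes, which matches the ONNX Transpose default. A quick standalone check of that iota trick (illustrative C++, not part of the commit):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::size_t n_dim = 4;
    std::vector<int64_t> perm(n_dim);
    // Counting up from the reverse end yields {n-1, ..., 1, 0}
    std::iota(perm.rbegin(), perm.rend(), 0);
    for(auto p : perm)
        std::cout << p << ' '; // prints: 3 2 1 0
}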
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_upsample : op_parser<parse_upsample>
{
    std::vector<op_desc> operators() const { return {{"Upsample"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "nearest")
            {
                MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
            }
        }

        auto arg_scale = args[1]->eval();
        check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
        std::vector<float> vec_scale;
        arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });

        auto in_s    = args[0]->get_shape();
        auto in_lens = in_s.lens();
        if(in_lens.size() != vec_scale.size())
        {
            MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
        }

        std::vector<std::size_t> out_lens(in_lens.size());
        std::transform(in_lens.begin(),
                       in_lens.end(),
                       vec_scale.begin(),
                       out_lens.begin(),
                       [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });

        std::vector<float> idx_scale(in_lens.size());
        std::transform(
            out_lens.begin(),
            out_lens.end(),
            in_lens.begin(),
            idx_scale.begin(),
            [](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });

        shape out_s{in_s.type(), out_lens};
        std::vector<int> ind(out_s.elements());

        // map out_idx to in_idx
        shape_for_each(out_s, [&](auto idx) {
            auto in_idx = idx;
            std::transform(idx.begin(),
                           idx.end(),
                           idx_scale.begin(),
                           in_idx.begin(),
                           // nearest mode
                           [](auto index, auto scale) {
                               return static_cast<std::size_t>(std::round(index * scale));
                           });
            ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
        });

        // reshape input to one-dimension
        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
        shape ind_s{shape::int32_type, out_lens};
        auto rsp     = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
        auto ins_ind = info.add_literal(literal(ind_s, ind));
        return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
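To see what the precomputed gather indices look like, here is the same nearest-neighbor index math for a 1-D input of length 2 upsampled by 2; a plain C++ sketch of the idx_scale formula above, not MIGraphX code:

#include <cmath>
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t in_len = 2;
    float scale        = 2.0f;
    auto out_len       = static_cast<std::size_t>(in_len * scale); // 4
    // Same rescaling factor as idx_scale above: (id - 1) / (od - 1)
    float idx_scale = (out_len == in_len) ? 1.0f : (in_len - 1.0f) / (out_len - 1.0f);
    for(std::size_t od = 0; od < out_len; ++od)
    {
        auto id = static_cast<std::size_t>(std::round(od * idx_scale));
        std::cout << od << " -> " << id << '\n'; // 0->0, 1->0, 2->1, 3->1
    }
}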
@@ -17,45 +17,28 @@ struct parse_where : op_parser<parse_where>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
 {
-    auto cond =
-        info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
-    auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens());
-    lens      = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
-    if(cond->get_shape().lens() != lens)
-    {
-        cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
-    }
+    auto lens =
+        compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
+    lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
+    if(args[0]->get_shape().lens() != lens)
+    {
+        args[0] =
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
+    }
     if(args[1]->get_shape().lens() != lens)
     {
         args[1] =
-            info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
     }
     if(args[2]->get_shape().lens() != lens)
     {
         args[2] =
-            info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
     }
-    // compute index
-    auto elem_num = args[1]->get_shape().elements();
-    // concatenation of input data
-    auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
-    std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
-    auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
-    std::vector<int> ind(elem_num);
-    std::iota(ind.begin(), ind.end(), 0);
-    shape ind_s{shape::int32_type, lens};
-    auto l_ind = info.add_literal(literal(ind_s, ind));
-    std::vector<int> offset(elem_num, elem_num);
-    auto l_offset   = info.add_literal(literal({shape::int32_type, lens}, offset));
-    auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
-    auto ins_ind    = info.add_instruction(make_op("add"), ins_offset, l_ind);
-    return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
+    return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
 }
};
...
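The removed lowering emulated the select by concatenating both branches and gathering at cond * elem_num + iota; the native op reduces all of that to the usual element-wise rule. A one-screen reference of what where computes (standalone C++ sketch):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<bool> cond = {true, false, true};
    std::vector<float> x   = {1.0f, 2.0f, 3.0f}; // taken where cond is true
    std::vector<float> y   = {9.0f, 8.0f, 7.0f}; // taken where cond is false
    for(std::size_t i = 0; i < cond.size(); ++i)
        std::cout << (cond[i] ? x[i] : y[i]) << ' '; // prints: 1 8 3
}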
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
    // the strings for the enum are the same as the values used for onnx parsing
    // but this enum is not onnx-specific: strings must be converted when parsing tf
    static const std::vector<std::string> pooling_mode_str = {"average", "max", "lpnorm"};
    os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
    return os;
}

std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
    static const std::vector<std::string> rnn_direction_str = {
        "forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
    return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
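The pattern above indexes a string table with the enum's underlying integer value, so enumerator order and table order must stay in sync. A self-contained sketch of the same idiom (illustrative enum, not the real header):

#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

enum class color { red, green, blue };

std::ostream& operator<<(std::ostream& os, color v)
{
    // Table order must match the enumerator declaration order
    static const std::vector<std::string> color_str = {"red", "green", "blue"};
    os << color_str[static_cast<std::underlying_type<color>::type>(v)];
    return os;
}

int main() { std::cout << color::green << '\n'; } // prints: green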
@@ -4,11 +4,11 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-void memory_coloring::apply(module& p) const
+void memory_coloring::apply(module& m) const
 {
     if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
     {
-        memory_coloring_impl opt(&p, allocation_op, verify);
+        memory_coloring_impl opt(&m, allocation_op, verify);
         opt.run();
     }
 }
...
@@ -40,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
     std::size_t size = s.bytes();
     if(size == 0)
         return false;
-    std::size_t element_size = size / s.elements();
+    std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
     live_range& segment = interval->segment;
     int vn              = segment.vn;
     std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
...
@@ -15,6 +15,8 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
+
 void validate_pass(module& mod, const pass& p, tracer trace)
 {
     (void)mod;
@@ -32,14 +34,6 @@ void validate_pass(module& mod, const pass& p, tracer trace)
     trace();
 #endif
 }
-void run_pass(module& mod, const pass& p, tracer trace)
-{
-    trace("Module: ", mod.name(), ", Pass: ", p.name());
-    assert(mod.validate() == mod.end());
-    p.apply(mod);
-    trace(mod);
-    validate_pass(mod, p, trace);
-}
 void run_pass(program& prog, const pass& p, tracer trace)
 {
     trace("Pass: ", p.name());
@@ -47,22 +41,69 @@ void run_pass(program& prog, const pass& p, tracer trace)
     trace(prog);
 }

+struct module_pm : module_pass_manager
+{
+    module* mod;
+    program* prog;
+    tracer* t;
+
+    module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
+        : mod(pmod), prog(pprog), t(pt)
+    {
+    }
+
+    template <class... Ts>
+    void trace(Ts&&... xs) const
+    {
+        assert(t);
+        (*t)(xs...);
+    }
+
+    virtual module& get_module() override
+    {
+        assert(mod);
+        return *mod;
+    }
+
+    virtual module* create_module(const std::string& name) override
+    {
+        assert(prog);
+        return prog->create_module(name);
+    }
+
+    virtual void run_pass(const pass& p) override
+    {
+        assert(mod);
+        trace("Module: ", mod->name(), ", Pass: ", p.name());
+        assert(mod->validate() == mod->end());
+        p.apply(*this);
+        trace(*mod);
+        validate_pass(*mod, p, *t);
+    }
+};
+
+module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
+
 void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
 {
+    if(enabled(MIGRAPHX_TRACE_PASSES{}))
+        trace = tracer{std::cout};
     for(const auto& p : passes)
     {
-        run_pass(mod, p, trace);
+        module_pm{&mod, nullptr, &trace}.run_pass(p);
     }
 }

 void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
 {
+    if(enabled(MIGRAPHX_TRACE_PASSES{}))
+        trace = tracer{std::cout};
     for(const auto& p : passes)
     {
         auto mods = prog.get_modules();
         for(const auto& mod : reverse(mods))
         {
-            run_pass(*mod, p, trace);
+            if(mod->bypass())
+                continue;
+            module_pm{mod, &prog, &trace}.run_pass(p);
         }
         run_pass(prog, p, trace);
     }
...
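For orientation, a hypothetical pass written against the interface introduced here might look like the sketch below. toy_pass and its body are illustrative only (not part of this commit), but get_module, create_module, and the pass_manager.hpp header come from the diff above:

#include <string>
#include <migraphx/pass_manager.hpp> // declares module_pass_manager (per this commit)

// Hypothetical example: a pass now receives a module_pass_manager instead of a
// bare module, so it can both edit the current module and create program-level
// submodules (e.g. for control-flow operators).
struct toy_pass
{
    std::string name() const { return "toy_pass"; }

    void apply(migraphx::module_pass_manager& mpm) const
    {
        migraphx::module& m = mpm.get_module(); // the module being rewritten
        // Submodules are created through the owning program:
        migraphx::module* sub = mpm.create_module(m.name() + ":toy");
        (void)sub; // a real pass would populate sub and reference it
    }
};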
@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 void preallocate_param::apply(module& m) const
 {
+    auto last = std::prev(m.end());
     for(auto ins : iterator_for(m))
     {
         if(ins->name() != "@param")
@@ -19,7 +20,9 @@ void preallocate_param::apply(module& m) const
         std::string id = m.name() + ":" + param;
         auto r         = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
         m.replace_instruction(ins, r);
+        m.move_instruction(ins, m.end());
     }
+    m.remove_instructions(std::next(last), m.end());
 }
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -20,7 +20,6 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
     int ec = 0;
     if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
         std::cout << cmd << std::endl;
-    std::array<char, 128> buffer;
     auto closer = [&](FILE* stream) {
         auto status = pclose(stream);
         ec          = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT
@@ -30,6 +29,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
         std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
         if(!pipe)
             MIGRAPHX_THROW("popen() failed: " + cmd);
+        std::array<char, 128> buffer;
         while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
             std_out(buffer.data());
     }
...
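The change only moves the read buffer declaration to after the pipe is known to be valid. The surrounding idiom, shown standalone: popen wrapped in a unique_ptr whose deleter runs pclose, with fgets draining the command's stdout (POSIX-only sketch, not this file):

#include <array>
#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

int main()
{
    std::string cmd = "echo hello";
    auto closer     = [](FILE* f) { pclose(f); }; // the deleter owns the pclose
    std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer);
    if(!pipe)
        throw std::runtime_error("popen() failed: " + cmd);
    std::array<char, 128> buffer; // declared only once the pipe exists
    while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
        std::cout << buffer.data();
}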
@@ -13,6 +13,7 @@
 #include <migraphx/algorithm.hpp>
 #include <migraphx/output_iterator.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/marker.hpp>
 #include <iostream>
 #include <sstream>
 #include <algorithm>
@@ -179,19 +180,78 @@ void program::finalize()
     mm->finalize(this->impl->ctx);
 }

+template <class T>
+std::string classify(T x)
+{
+    switch(std::fpclassify(x))
+    {
+    case FP_INFINITE: return "inf";
+    case FP_NAN: return "nan";
+    case FP_NORMAL: return "normal";
+    case FP_SUBNORMAL: return "subnormal";
+    case FP_ZERO: return "zero";
+    default: return "unknown";
+    }
+}
+
+std::unordered_set<std::string> classify_argument(const argument& a)
+{
+    std::unordered_set<std::string> result;
+    a.visit(
+        [&](auto t) {
+            for(const auto& x : t)
+                result.insert(classify(x));
+        },
+        [&](const auto& xs) {
+            for(const auto& x : xs)
+            {
+                auto r = classify_argument(x);
+                result.insert(r.begin(), r.end());
+            }
+        });
+    return result;
+}
+
+void preview_argument(std::ostream& os, const argument& a)
+{
+    a.visit(
+        [&](auto t) {
+            if(t.size() <= 10)
+            {
+                os << t;
+            }
+            else
+            {
+                os << to_string_range(t.begin(), t.begin() + 5);
+                os << ", ..., ";
+                os << to_string_range(t.end() - 5, t.end());
+            }
+        },
+        [&](const auto& xs) {
+            for(const auto& x : xs)
+            {
+                os << '{';
+                preview_argument(os, x);
+                os << '}';
+            }
+        });
+}
+
 template <class F>
 std::vector<argument> generic_eval(const module* mod,
                                    context& ctx,
                                    std::unordered_map<std::string, argument> params,
                                    std::unordered_map<instruction_ref, argument> results,
-                                   F trace)
+                                   F make_trace)
 {
     assert(mod->validate() == mod->end());
     results.reserve(mod->size() * 2);
     std::vector<argument> values;
     values.reserve(16);
+    auto trace = make_trace(mod);
     for(auto ins : iterator_for(*mod))
     {
+        assert(results.find(ins) == results.end());
         const auto& name = ins->name();
         if(name == "@literal")
         {
@@ -240,7 +300,8 @@ std::vector<argument> generic_eval(const module* mod,
             const auto& mod_args = ins->module_inputs();
             auto module_eval     = [&](module_ref smod,
                                    const std::unordered_map<std::string, argument>& inputs) {
-                return generic_eval(smod, ctx, inputs, results, trace);
+                auto ssctx = ctx;
+                return generic_eval(smod, ssctx, inputs, results, make_trace);
             };

             results.emplace(ins, trace(ins, [&] {
@@ -249,6 +310,7 @@ std::vector<argument> generic_eval(const module* mod,
             }));
         }
         assert(results.find(ins) != results.end());
+        assert(results.at(ins).get_shape() == ins->get_shape());
     }
     return {results.at(std::prev(mod->end()))};
 }
@@ -257,50 +319,90 @@ template <class F>
 std::vector<argument> generic_eval(const program& p,
                                    context& ctx,
                                    std::unordered_map<std::string, argument> params,
-                                   F trace)
+                                   F make_trace)
 {
     const module* mm = p.get_main_module();
-    return generic_eval(mm, ctx, params, {}, trace);
+    return generic_eval(mm, ctx, params, {}, make_trace);
 }

 std::vector<argument> program::eval(parameter_map params) const
 {
     auto& ctx = this->impl->ctx;
 #ifndef NDEBUG
-    auto sctx          = ctx;
-    auto check_context = [&](auto f) {
-        assert(is_shared(ctx, sctx));
-        auto x = f();
-        sctx   = ctx;
-        return x;
-    };
+    auto with_check_context = [&](auto f) {
+        return [=, &ctx](auto&&) {
+            auto sctx          = std::make_shared<context>(ctx);
+            auto check_context = [=, &ctx](auto g) {
+                assert(is_shared(ctx, *sctx));
+                auto x = g();
+                *sctx  = ctx;
+                return x;
+            };
+            return [=](auto&&... xs) { return f(xs..., check_context); };
+        };
+    };
 #else
-    auto check_context = [](auto f) { return f(); };
+    auto with_check_context = [](auto f) {
+        return [=](auto&&) {
+            return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
+        };
+    };
 #endif

     auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});

     if(trace_level > 0)
     {
-        return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
-            ctx.finish();
-            std::cout << "Run instruction: ";
-            this->debug_print(ins);
-            timer t{};
-            auto result = check_context(f);
-            double t1   = t.record<milliseconds>();
-            ctx.finish();
-            double t2 = t.record<milliseconds>();
-            std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
-            if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load")
-                std::cout << "Output: " << result << std::endl;
-            return result;
-        });
+        std::unordered_map<instruction_ref, std::string> ins_out;
+        // get instruction names
+        this->print([&](auto x, auto ins_names) {
+            std::stringstream ss;
+            instruction::print(ss, x, ins_names);
+            ins_out[x] = ss.str();
+        });
+        return generic_eval(*this,
+                            ctx,
+                            std::move(params),
+                            with_check_context([&](auto& ins, auto f, auto&& check_context) {
+                                ctx.finish();
+                                std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
+                                timer t{};
+                                auto result = check_context(f);
+                                double t1   = t.record<milliseconds>();
+                                ctx.finish();
+                                double t2 = t.record<milliseconds>();
+                                std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
+                                if(trace_level > 1 and ins->name().front() != '@' and
+                                   ins->name() != "load" and not result.empty())
+                                {
+                                    target tgt  = make_target(this->impl->target_name);
+                                    auto buffer = tgt.copy_from(result);
+                                    if(trace_level == 2)
+                                    {
+                                        std::cout << "Output has "
+                                                  << to_string_range(classify_argument(buffer))
+                                                  << std::endl;
+                                        std::cout << "Output: ";
+                                        preview_argument(std::cout, buffer);
+                                        std::cout << std::endl;
+                                    }
+                                    else
+                                    {
+                                        std::cout << "Output: " << buffer << std::endl;
+                                    }
+                                }
+                                return result;
+                            }));
     }
     else
     {
-        return generic_eval(
-            *this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
+        return generic_eval(*this,
                            ctx,
                            std::move(params),
+                            with_check_context([&](auto&, auto f, auto&& check_context) {
+                                return check_context(f);
+                            }));
     }
 }
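The with_check_context refactor above is lambda currying: generic_eval now receives a factory (make_trace) called once per module, and each produced trace callback is handed an extra check_context argument. The shape of that wrapping, reduced to a standalone C++ sketch (illustrative names, not the MIGraphX code):

#include <iostream>

int main()
{
    // with_check_context-style wrapper: takes a callback f expecting an extra
    // 'check' argument, returns a factory whose product forwards f with it.
    auto with_check = [](auto f) {
        return [=](auto&& /*mod*/) {                 // plays the role of make_trace(mod)
            auto check = [](auto g) { return g(); }; // stand-in for check_context
            return [=](auto&&... xs) { return f(xs..., check); };
        };
    };

    auto make_trace = with_check([](int ins, auto run, auto&& check) {
        std::cout << "instruction " << ins << '\n';
        return check(run);
    });

    auto trace = make_trace("main");          // one trace per module
    int result = trace(42, [] { return 7; }); // trace(ins, f)
    std::cout << result << '\n';              // prints: 7
}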
@@ -483,7 +585,28 @@ std::string perf_group(const operation& op)
     return op.name();
 }

+void program::mark(const parameter_map& params, marker&& m)
+{
+    auto& ctx = this->impl->ctx;
+    // Run once by itself
+    eval(params);
+    ctx.finish();
+    // Start marking
+    m.mark_start(*this);
+    generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
+                     argument result;
+                     m.mark_start(ins);
+                     result = f();
+                     m.mark_stop(ins);
+                     return result;
+                 }));
+    m.mark_stop(*this);
+}
+
-void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
+void program::perf_report(std::ostream& os,
+                          std::size_t n,
+                          parameter_map params,
+                          std::size_t batch) const
 {
     auto& ctx = this->impl->ctx;
     // Run once by itself
@@ -502,21 +625,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     std::sort(total_vec.begin(), total_vec.end());
     std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
     // Fill the map
-    generic_eval(*this, ctx, params, [&](auto ins, auto) {
-        ins_vec[ins].reserve(n);
-        return argument{};
-    });
+    generic_eval(*this, ctx, params, always([&](auto ins, auto) {
+                     ins_vec[ins].reserve(n);
+                     return argument{ins->get_shape(), nullptr};
+                 }));
     // Run and time each instruction
     for(std::size_t i = 0; i < n; i++)
     {
-        generic_eval(*this, ctx, params, [&](auto ins, auto f) {
+        generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
             argument result;
             ins_vec[ins].push_back(time<milliseconds>([&] {
                 result = f();
                 ctx.finish();
             }));
             return result;
-        });
+        }));
     }
     for(auto&& p : ins_vec)
         std::sort(p.second.begin(), p.second.end());
@@ -575,7 +699,8 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     os << std::endl;

-    os << "Rate: " << rate << "/sec" << std::endl;
+    os << "Batch size: " << batch << std::endl;
+    os << "Rate: " << rate * batch << "/sec" << std::endl;
     os << "Total time: " << total_time << "ms" << std::endl;
     os << "Total instructions time: " << total_instruction_time << "ms" << std::endl;
     os << "Overhead time: " << overhead_time << "ms"
@@ -624,6 +749,14 @@ void program::print(
     }
 }

+void program::print(
+    const std::function<void(instruction_ref ins,
+                             std::unordered_map<instruction_ref, std::string>)>& print_func) const
+{
+    std::unordered_map<instruction_ref, std::string> names;
+    this->print(names, print_func);
+}
+
 void program::print_graph(std::ostream& os, bool brief) const
 {
     const auto* mm = this->get_main_module();
@@ -645,7 +778,9 @@ void program::print_cpp(std::ostream& os) const
 void program::dry_run(std::unordered_map<std::string, argument> params) const
 {
     auto& ctx = this->impl->ctx;
-    generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
+    generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
+                     return argument{ins->get_shape(), nullptr};
+                 }));
 }

 void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
@@ -689,11 +824,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputIterator out)
     std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
         return mod->name();
     });
-    transform_if(m.begin(),
-                 m.end(),
-                 out,
-                 [&](auto&& pp) { return not contains(used, pp.first); },
-                 [](auto&& pp) { return &pp.second; });
+    transform_if(
+        m.begin(),
+        m.end(),
+        out,
+        [&](auto&& pp) { return not contains(used, pp.first); },
+        [](auto&& pp) { return &pp.second; });
 }

 std::vector<const module*> program::get_modules() const
@@ -745,6 +881,22 @@ void program::remove_module(const std::string& name)
                        impl->modules.at(name).end(),
                        [&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
            "Instruction referenced in another module");

+    // if an instruction has an input outside of the current module, need to remove
+    // the instruction from its input's outputs
+    auto& mod = impl->modules.at(name);
+    for(auto ins : iterator_for(mod))
+    {
+        auto inputs = ins->inputs();
+        for(auto in : inputs)
+        {
+            if(not mod.has_instruction(in))
+            {
+                in->remove_output(ins);
+            }
+        }
+    }
+
     impl->modules.erase(name);
 }
...
@@ -3,6 +3,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/functional.hpp>
+#include <migraphx/par_for.hpp>
 #include <unordered_set>

 namespace migraphx {
@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
     return false;
 }

-void propagate_constant::apply(module& p) const
+bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
+
+void propagate_constant::apply(module& m) const
 {
-    for(auto i : iterator_for(p))
+    std::unordered_set<instruction_ref> const_instrs;
+    auto last = std::prev(m.end());
+
+    // Find instructions that can be evaluated to a literal
+    for(auto i : iterator_for(m))
     {
-        if(i->name() != "@literal")
+        if(is_const(i) and i != last)
             continue;
-        if(i->outputs().empty())
-            continue;
-        fix([&](auto self, auto ins) {
-            std::unordered_set<instruction_ref> children(ins->outputs().begin(),
-                                                         ins->outputs().end());
-            for(auto child : children)
-            {
-                if(child->name() == "@literal" or skip_propogate(child))
-                {
-                    self(child);
-                    continue;
-                }
-                auto r = child->eval();
-                if(not r.empty())
-                {
-                    assert(r.get_shape() == child->get_shape());
-                    auto l = p.add_literal(r.get_shape(), r.data());
-                    self(p.replace_instruction(child, l));
-                }
-            }
-        })(i);
+
+        std::copy_if(
+            i->inputs().begin(),
+            i->inputs().end(),
+            std::inserter(const_instrs, const_instrs.begin()),
+            [&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
+    }
+
+    // Compute literals in parallel
+    std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
+    std::vector<argument> literals(const_instrs_vec.size());
+    par_for(const_instrs_vec.size(), 1, [&](const auto i) {
+        literals[i] = const_instrs_vec[i]->eval();
+    });
+
+    // Replace instructions in m
+    for(size_t i = 0; i < const_instrs_vec.size(); i++)
+    {
+        if(not literals[i].empty())
+        {
+            assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
+            auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
+            m.replace_instruction(const_instrs_vec[i], l);
+        }
     }
 }
...
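The rewritten pass evaluates independent constant roots concurrently and then mutates the module serially. A minimal sketch of that evaluate-in-parallel, apply-serially split, using std::async in place of MIGraphX's par_for (illustrative, not the pass itself):

#include <functional>
#include <future>
#include <iostream>
#include <vector>

int main()
{
    // Stand-ins for constant subgraphs: each closure "evaluates" to a literal.
    std::vector<std::function<int()>> const_roots = {
        [] { return 2 + 3; }, [] { return 7 * 6; }, [] { return 10 - 1; }};

    // Phase 1: evaluation is side-effect free, so it can run concurrently.
    std::vector<std::future<int>> futs;
    for(auto& f : const_roots)
        futs.push_back(std::async(std::launch::async, f));

    // Phase 2: mutation (replace_instruction in the real pass) stays serial.
    for(auto& fut : futs)
        std::cout << fut.get() << '\n'; // 5, 42, 9
}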
@@ -3,8 +3,11 @@
 #include <pybind11/stl.h>
 #include <pybind11/numpy.h>
 #include <migraphx/program.hpp>
+#include <migraphx/instruction_ref.hpp>
+#include <migraphx/operation.hpp>
 #include <migraphx/quantization.hpp>
 #include <migraphx/generate.hpp>
+#include <migraphx/instruction.hpp>
 #include <migraphx/ref/target.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/tf.hpp>
@@ -95,7 +98,6 @@ migraphx::value to_value(py::kwargs kwargs)
         auto&& val = arg.second;
         visit_py(val, [&](auto py_val) { v[key] = py_val; });
     }
-
     return v;
 }
 } // namespace migraphx
@@ -211,12 +213,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
 MIGRAPHX_PYBIND11_MODULE(migraphx, m)
 {
     py::class_<migraphx::shape>(m, "shape")
-        .def(py::init<>())
+        .def(py::init([](py::kwargs kwargs) {
+            auto v    = migraphx::to_value(kwargs);
+            auto t    = migraphx::shape::parse_type(v.get("type", "float"));
+            auto lens = v.get<std::size_t>("lens", {1});
+            if(v.contains("strides"))
+                return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
+            else
+                return migraphx::shape(t, lens);
+        }))
         .def("type", &migraphx::shape::type)
         .def("lens", &migraphx::shape::lens)
         .def("strides", &migraphx::shape::strides)
         .def("elements", &migraphx::shape::elements)
         .def("bytes", &migraphx::shape::bytes)
+        .def("type_string", &migraphx::shape::type_string)
         .def("type_size", &migraphx::shape::type_size)
         .def("packed", &migraphx::shape::packed)
         .def("transposed", &migraphx::shape::transposed)
@@ -247,13 +258,46 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
     py::class_<migraphx::target>(m, "target");

-    py::class_<migraphx::module>(m, "module")
+    py::class_<migraphx::instruction_ref>(m, "instruction_ref");
+
+    py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
         .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
-        .def("__eq__", std::equal_to<migraphx::module>{})
-        .def("__ne__", std::not_equal_to<migraphx::module>{})
+        .def(
+            "add_instruction",
+            [](migraphx::module& mm,
+               const migraphx::operation& op,
+               std::vector<migraphx::instruction_ref>& args,
+               std::vector<migraphx::module*>& mod_args) {
+                return mm.add_instruction(op, args, mod_args);
+            },
+            py::arg("op"),
+            py::arg("args"),
+            py::arg("mod_args") = std::vector<migraphx::module*>{})
+        .def(
+            "add_literal",
+            [](migraphx::module& mm, py::buffer data) {
+                py::buffer_info info = data.request();
+                auto literal_shape   = to_shape(info);
+                return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
+            },
+            py::arg("data"))
+        .def(
+            "add_parameter",
+            [](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
+                return mm.add_parameter(name, shape);
+            },
+            py::arg("name"),
+            py::arg("shape"))
+        .def(
+            "add_return",
+            [](migraphx::module& mm, std::vector<migraphx::instruction_ref>& args) {
+                return mm.add_return(args);
+            },
+            py::arg("args"))
         .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });

     py::class_<migraphx::program>(m, "program")
+        .def(py::init([]() { return migraphx::program(); }))
         .def("get_parameter_names", &migraphx::program::get_parameter_names)
         .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
         .def("get_output_shapes", &migraphx::program::get_output_shapes)
@@ -268,11 +312,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
              py::arg("t"),
              py::arg("offload_copy") = true,
              py::arg("fast_math") = true)
-        .def("get_main_module",
-             [](migraphx::program& p) {
-                 auto* mm = p.get_main_module();
-                 return *mm;
-             })
+        .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
+        .def(
+            "create_module",
+            [](migraphx::program& p, const std::string& name) { return p.create_module(name); },
+            py::arg("name"))
         .def("run",
              [](migraphx::program& p, py::dict params) {
                  migraphx::parameter_map pm;
@@ -303,86 +347,94 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("name", &migraphx::operation::name);

-    m.def("parse_tf",
-          [](const std::string& filename,
-             bool is_nhwc,
-             unsigned int batch_size,
-             std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
-             std::vector<std::string> output_names) {
-              return migraphx::parse_tf(
-                  filename,
-                  migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
-          },
-          "Parse tf protobuf (default format is nhwc)",
-          py::arg("filename"),
-          py::arg("is_nhwc") = true,
-          py::arg("batch_size") = 1,
-          py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
-          py::arg("output_names") = std::vector<std::string>());
+    m.def(
+        "parse_tf",
+        [](const std::string& filename,
+           bool is_nhwc,
+           unsigned int batch_size,
+           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
+           std::vector<std::string> output_names) {
+            return migraphx::parse_tf(
+                filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
+        },
+        "Parse tf protobuf (default format is nhwc)",
+        py::arg("filename"),
+        py::arg("is_nhwc") = true,
+        py::arg("batch_size") = 1,
+        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
+        py::arg("output_names") = std::vector<std::string>());

-    m.def("parse_onnx",
-          [](const std::string& filename,
-             unsigned int default_dim_value,
-             std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
-             bool skip_unknown_operators,
-             bool print_program_on_error) {
-              migraphx::onnx_options options;
-              options.default_dim_value      = default_dim_value;
-              options.map_input_dims         = map_input_dims;
-              options.skip_unknown_operators = skip_unknown_operators;
-              options.print_program_on_error = print_program_on_error;
-              return migraphx::parse_onnx(filename, options);
-          },
-          "Parse onnx file",
-          py::arg("filename"),
-          py::arg("default_dim_value") = 1,
-          py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
-          py::arg("skip_unknown_operators") = false,
-          py::arg("print_program_on_error") = false);
+    m.def(
+        "parse_onnx",
+        [](const std::string& filename,
+           unsigned int default_dim_value,
+           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
+           bool skip_unknown_operators,
+           bool print_program_on_error,
+           int64_t max_loop_iterations) {
+            migraphx::onnx_options options;
+            options.default_dim_value      = default_dim_value;
+            options.map_input_dims         = map_input_dims;
+            options.skip_unknown_operators = skip_unknown_operators;
+            options.print_program_on_error = print_program_on_error;
+            options.max_loop_iterations    = max_loop_iterations;
+            return migraphx::parse_onnx(filename, options);
+        },
+        "Parse onnx file",
+        py::arg("filename"),
+        py::arg("default_dim_value") = 1,
+        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
+        py::arg("skip_unknown_operators") = false,
+        py::arg("print_program_on_error") = false,
+        py::arg("max_loop_iterations") = 10);

-    m.def("parse_onnx_buffer",
-          [](const std::string& onnx_buffer,
-             unsigned int default_dim_value,
-             std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
-             bool skip_unknown_operators,
-             bool print_program_on_error) {
-              migraphx::onnx_options options;
-              options.default_dim_value      = default_dim_value;
-              options.map_input_dims         = map_input_dims;
-              options.skip_unknown_operators = skip_unknown_operators;
-              options.print_program_on_error = print_program_on_error;
-              return migraphx::parse_onnx_buffer(onnx_buffer, options);
-          },
-          "Parse onnx file",
-          py::arg("filename"),
-          py::arg("default_dim_value") = 1,
-          py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
-          py::arg("skip_unknown_operators") = false,
-          py::arg("print_program_on_error") = false);
+    m.def(
+        "parse_onnx_buffer",
+        [](const std::string& onnx_buffer,
+           unsigned int default_dim_value,
+           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
+           bool skip_unknown_operators,
+           bool print_program_on_error) {
+            migraphx::onnx_options options;
+            options.default_dim_value      = default_dim_value;
+            options.map_input_dims         = map_input_dims;
+            options.skip_unknown_operators = skip_unknown_operators;
+            options.print_program_on_error = print_program_on_error;
+            return migraphx::parse_onnx_buffer(onnx_buffer, options);
+        },
+        "Parse onnx file",
+        py::arg("filename"),
+        py::arg("default_dim_value") = 1,
+        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
+        py::arg("skip_unknown_operators") = false,
+        py::arg("print_program_on_error") = false);

-    m.def("load",
-          [](const std::string& name, const std::string& format) {
-              migraphx::file_options options;
-              options.format = format;
-              return migraphx::load(name, options);
-          },
-          "Load MIGraphX program",
-          py::arg("filename"),
-          py::arg("format") = "msgpack");
+    m.def(
+        "load",
+        [](const std::string& name, const std::string& format) {
+            migraphx::file_options options;
+            options.format = format;
+            return migraphx::load(name, options);
+        },
+        "Load MIGraphX program",
+        py::arg("filename"),
+        py::arg("format") = "msgpack");

-    m.def("save",
-          [](const migraphx::program& p, const std::string& name, const std::string& format) {
-              migraphx::file_options options;
-              options.format = format;
-              return migraphx::save(p, name, options);
-          },
-          "Save MIGraphX program",
-          py::arg("p"),
-          py::arg("filename"),
-          py::arg("format") = "msgpack");
+    m.def(
+        "save",
+        [](const migraphx::program& p, const std::string& name, const std::string& format) {
+            migraphx::file_options options;
+            options.format = format;
+            return migraphx::save(p, name, options);
+        },
+        "Save MIGraphX program",
+        py::arg("p"),
+        py::arg("filename"),
+        py::arg("format") = "msgpack");

     m.def("get_target", &migraphx::make_target);
     m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
+    m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
     m.def("quantize_fp16",
           &migraphx::quantize_fp16,
           py::arg("prog"),
...
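The new shape constructor relies on pybind11's py::init overload that takes a factory lambda, so Python keyword arguments can drive the choice of C++ constructor. The same pattern on a toy type (a sketch assuming pybind11; the point type and module name are illustrative):

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct point
{
    int x = 0;
    int y = 0;
};

PYBIND11_MODULE(example, m)
{
    py::class_<point>(m, "point")
        // Factory-style init: inspect kwargs, then construct the C++ object.
        .def(py::init([](py::kwargs kwargs) {
            point p;
            if(kwargs.contains("x"))
                p.x = kwargs["x"].cast<int>();
            if(kwargs.contains("y"))
                p.y = kwargs["y"].cast<int>();
            return p;
        }));
}

From Python this allows e.g. point(x=1, y=2), mirroring how shape(type="float", lens=[2, 3]) works with the binding above.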
-#include <migraphx/float_equal.hpp>
-#include <migraphx/instruction_ref.hpp>
 #include <migraphx/quantization.hpp>
+#include <migraphx/quantize_fp16.hpp>
+#include <migraphx/quantize_int8.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/simplify_qdq.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/iterator_for.hpp>
-#include <migraphx/op/convert.hpp>
-#include <migraphx/op/clip.hpp>
-#include <migraphx/op/round.hpp>
-#include <migraphx/op/dot.hpp>
-#include <migraphx/op/mul.hpp>
-#include <migraphx/op/add.hpp>
-#include <migraphx/op/quant_dot.hpp>
-#include <migraphx/op/capture.hpp>
-#include <migraphx/op/convolution.hpp>
-#include <migraphx/op/quant_convolution.hpp>
-#include <migraphx/op/multibroadcast.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/op/capture.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/target.hpp>
-#include <utility>
-#include <iomanip>
-#include <migraphx/serialize.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/pass_manager.hpp>
-#include <fstream>
-#include <algorithm>
 #include <set>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
-instruction_ref insert_quant_ins(module& modl,
-                                 instruction_ref& ins,
-                                 shape::type_t type,
-                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins,
-                                 float scale = 1.0f,
-                                 float shift = 0.0f)
-{
-    if(map_ins.count(ins) > 0)
-    {
-        return map_ins[ins];
-    }
-
-    if(ins->name() == "undefined")
-    {
-        return ins;
-    }
-
-    assert(ins->get_shape().type() == shape::float_type or
-           ins->get_shape().type() == shape::double_type or
-           ins->get_shape().type() == shape::int32_type or
-           ins->get_shape().type() == shape::half_type);
-    instruction_ref quant_ins{};
-    auto insert_loc = std::next(ins);
-    if(type == shape::int8_type)
-    {
-        auto scaled_ins = ins;
-        if(scale != 1.0f)
-        {
-            auto float_ins = scaled_ins;
-            if(scaled_ins->get_shape().type() != shape::float_type)
-            {
-                float_ins = modl.insert_instruction(
-                    insert_loc,
-                    make_op("convert", {{"target_type", to_value(shape::float_type)}}),
-                    scaled_ins);
-            }
-            std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
-            auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
-            scaled_ins   = modl.insert_instruction(insert_loc, make_op("mul"), l_scale, float_ins);
-        }
-
-        auto shifted_ins = scaled_ins;
-        if(shift != 0.0f)
-        {
-            auto float_ins = shifted_ins;
-            if(shifted_ins->get_shape().type() != shape::float_type)
-            {
-                float_ins = modl.insert_instruction(
-                    insert_loc,
-                    make_op("convert", {{"target_type", to_value(shape::float_type)}}),
-                    shifted_ins);
-            }
-            std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
-            auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
-            shifted_ins  = modl.insert_instruction(insert_loc, make_op("add"), l_shift, float_ins);
-        }
-
-        auto rounded_ins  = modl.insert_instruction(insert_loc, make_op("round"), shifted_ins);
-        auto rounded_lens = rounded_ins->get_shape().lens();
-        auto max_clip     = modl.add_literal(127.0f);
-        auto min_clip     = modl.add_literal(-128.0f);
-        max_clip          = modl.insert_instruction(
-            insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
-        min_clip = modl.insert_instruction(
-            insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
-        auto clipped_ins =
-            modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
-        quant_ins = modl.insert_instruction(
-            insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
-    }
-    else
-    {
-        quant_ins =
-            modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
-    }
-
-    map_ins[ins] = quant_ins;
-    return quant_ins;
-}
@@ -119,337 +30,14 @@ instruction_ref insert_quant_ins(module& modl,
 // This function is to convert any instructions specified in the input
 // from double or float to float16 by inserting a convert operator.
 // For the conversion, there could be cases of overflowing, but it
 // truncate of the input to get the fp16.
 void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
 {
-    auto* mm = prog.get_main_module();
-    std::unordered_map<instruction_ref, instruction_ref> map_fp16;
-    for(auto ins : iterator_for(*mm))
-    {
-        if(ins->name() == "@return")
-            break;
-
-        // all indicates every instruction is converted
-        if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
-        {
-            continue;
-        }
-
-        shape::type_t orig_type = ins->get_shape().type();
-
-        // process all inputs, if input is a fp32 or fp64, convert it
-        // to a fp16 by adding a convert operator.
-        auto inputs = ins->inputs();
-        std::vector<instruction_ref> converted_inputs;
-        for(auto input : inputs)
-        {
-            auto s = input->get_shape();
-            if(s.type() == shape::float_type || s.type() == shape::double_type)
-            {
-                // if the input is a convert operator, uses its input
-                // as its current input
-                instruction_ref input_fp16{};
-                if(input->name() == "convert" and
-                   input->inputs().front()->get_shape().type() == shape::half_type)
-                {
-                    input_fp16 = input->inputs().front();
-                }
-                else
-                {
-                    input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
-                }
-                converted_inputs.push_back(input_fp16);
-            }
-            else
-            {
-                converted_inputs.push_back(input);
-            }
-        }
-
-        // no change for the input, go to the next instruction
-        if(inputs == converted_inputs)
-        {
-            continue;
-        }
-
-        auto op        = ins->get_operator();
-        auto ins_shape = compute_shape(op, converted_inputs);
-        if(ins_shape.type() != orig_type)
-        {
-            // check the dead code case to avoid assert
-            bool output_empty  = ins->outputs().empty();
-            auto ins_orig_type = mm->insert_instruction(
-                std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
-            if(!output_empty)
-            {
-                mm->replace_instruction(ins, ins_orig_type);
-            }
-        }
-
-        mm->replace_instruction(ins, op, converted_inputs);
-    }
+    run_passes(prog,
+               {quantize_fp16_pass{ins_names},
+                eliminate_common_subexpression{},
+                dead_code_elimination{},
+                simplify_reshapes{},
+                dead_code_elimination{},
+                simplify_qdq{},
+                dead_code_elimination{}});
 }
-static void ins_quantize_int8(module& modl,
-                              instruction_ref ins,
-                              std::vector<instruction_ref>& converted_inputs,
-                              const std::vector<std::pair<float, float>>& ins_quant_params)
-{
-    auto orig_type = ins->get_shape().type();
-    auto inputs    = ins->inputs();
-    if(ins->name() == "dot")
-    {
-        auto dot_op     = any_cast<op::dot>(ins->get_operator());
-        float new_alpha = dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first);
-        float new_beta  = dot_op.beta;
-        // We need additional checking about the quant_alpha value. If
-        // abs(quant_alpha) > 50 (some tmp value set here), we can convert
-        // it to an integer as the new_alpha in the quant_dot
-        float threshold = 50.0f;
-        if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
-        {
-            int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
-            int32_t quant_beta  = static_cast<int32_t>(std::round(new_beta));
-            if(shape::int32_type == orig_type)
-            {
-                modl.replace_instruction(
-                    ins,
-                    make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
-                    converted_inputs);
-            }
-            else
-            {
-                auto quant_dot = modl.insert_instruction(
-                    ins,
-                    make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
-                    converted_inputs);
-                modl.replace_instruction(
-                    ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
-            }
-        }
-        // either alpha or beta cannot be quantized because of too big
-        // relative rounding error
-        else
-        {
-            if(converted_inputs.size() == 3)
-            {
-                converted_inputs.pop_back();
-            }
-            auto q_dot = modl.insert_instruction(
-                ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
-            auto f_dot = modl.insert_instruction(
-                ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
-            auto c_shape = q_dot->get_shape();
-            std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
-            auto l_alpha =
-                modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
-            if(inputs.size() == 3 and dot_op.beta != 0.0f)
-            {
-                auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
-                std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
-                auto l_beta =
-                    modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
-                instruction_ref beta_c{};
-                if(orig_type != shape::float_type)
-                {
-                    auto fp32_c = modl.insert_instruction(
-                        ins,
-                        make_op("convert", {{"target_type", to_value(shape::float_type)}}),
-                        inputs.back());
-                    beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
-                }
-                else
-                {
-                    beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
-                }
-                if(orig_type == shape::float_type)
-                {
-                    modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
-                }
-                else
-                {
-                    auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
-                    modl.replace_instruction(
-                        ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
-                }
-            }
-            else
-            {
-                if(orig_type == shape::float_type)
-                {
-                    modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
-                }
-                else
-                {
-                    auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
-                    modl.replace_instruction(
-                        ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
-                }
-            }
-        }
-    }
-    else if(ins->name() == "convolution")
-    {
-        // Current MIOpen convolution does not support alpha and beta,
-        // so we need a separate multiply to adjust the output
-        auto conv_op       = any_cast<op::convolution>(ins->get_operator());
-        auto padding       = conv_op.padding;
-        auto stride        = conv_op.stride;
-        auto dilation      = conv_op.dilation;
-        auto padding_mode  = conv_op.padding_mode;
-        auto group         = conv_op.group;
-        auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
-
-        auto quant_conv = modl.insert_instruction(
-            ins,
-            op::quant_convolution{padding, stride, dilation, padding_mode, group},
-            converted_inputs);
-
-        float threshold = 50.0f;
-        std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
-        if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
-        {
-            auto l_factor = modl.add_literal(
-                literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
-            modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
-        }
-        // convert quant_conv output to float type, multiply the factor and
-        // conver back to original type
-        else
-        {
-            auto float_conv = modl.insert_instruction(
-                ins,
-                make_op("convert", {{"target_type", to_value(shape::float_type)}}),
-                quant_conv);
-            auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
-            if(orig_type == shape::float_type)
-            {
-                modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
-            }
-            else
-            {
-                auto adjusted_conv =
-                    modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
-                modl.replace_instruction(
-                    ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
-            }
-        }
-    }
-    else
-    {
-        MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name());
-    }
-}
// int8 quantization is different from fp16 since int8 can only handle value
// -128 ~ 127. To convert the float or double to int8, we need a scale and
// a shift, then the convert can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names)
{
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < quant_params.size(); ++i)
{
auto param = quant_params.at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
// For now, we only support the int8 quantization of gemm and convolution
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
if(not contains(ins_names, ins->name()))
{
continue;
}
// for the dot operator, there could be 2 or 3 input arguments
// if the 3rd argument is available, convert it to an int32.
std::vector<instruction_ref> converted_inputs;
// process all inputs, if input is a fp32 or fp64, convert it
// to a int8 type by adding a convert operator and replace
// the operator with the corresponding int8 version
auto inputs = ins->inputs();
std::vector<std::pair<float, float>> ins_quant_params;
for(auto input : inputs)
{
// calculate the index of each instruction to be quantized
std::size_t ins_index =
(map_ins_index.count(input) > 0) ? map_ins_index[input] : quant_param_index++;
map_ins_index[input] = ins_index;
auto param = quant_params[map_ins_index[input]];
ins_quant_params.push_back(param);
// In general, the target_type is int8, but for the dot
// operation, if it has 3 inputs, then the last one should
// be converted to int32_type
shape::type_t quant_type = shape::int8_type;
if((ins->name() == "dot") and (inputs.size() == 3) and (input == inputs.back()))
{
quant_type = shape::int32_type;
}
auto s = input->get_shape();
if((s.type() == shape::float_type or s.type() == shape::double_type or
s.type() == shape::half_type or s.type() == shape::int32_type) and
s.type() != quant_type)
{
// if the input is a convert operator producing the quantized
// type, use the convert's input directly
instruction_ref quant_input{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == quant_type)
{
quant_input = input->inputs().front();
// the scale in this case is not used, so set the scale
// to 1.0f for this parameter
ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
}
else
{
quant_input = insert_quant_ins(
*mm, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
else
{
converted_inputs.push_back(input);
}
}
// no input changed, so go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
{
MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
}
@@ -457,87 +45,14 @@
 }
 
 void quantize_int8(program& prog,
                    const target& t,
                    const std::vector<parameter_map>& calibration,
                    const std::vector<std::string>& ins_names)
 {
-    // insert capture operator
-    auto cap_prog          = prog;
-    auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
-
-    // use the calibration data to compute the quantization scale
-    cap_prog.compile(t);
-
-    // use all calibration data to run the program to calculate the
-    // quantization scale and shift
-    for(auto&& arg : calibration)
-    {
-        parameter_map m;
-        for(auto&& x : cap_prog.get_parameter_shapes())
-        {
-            if(arg.count(x.first) > 0)
-            {
-                assert(x.second == arg.at(x.first).get_shape());
-                m[x.first] = t.copy_to(arg.at(x.first));
-            }
-            else
-            {
-                m[x.first] = t.allocate(x.second);
-            }
-        }
-        cap_prog.eval(m);
-    }
-
-    quantize_int8_impl(prog, *int8_quant_params, ins_names);
-}
-
-// For each input argument, we need to insert a
-// capture operator to compute the scale and shift
-std::size_t capture_arguments(program& prog,
-                              const std::vector<std::string>& ins_names,
-                              const std::function<void(std::size_t, std::vector<argument>)>& func)
-{
-    auto* mm                = prog.get_main_module();
-    size_t num_quant_params = 0;
-    // the int8 quantization only supports dot and convolution
-    std::set<std::string> op_names = {"dot", "convolution"};
+    std::set<std::string> op_names = {"convolution", "dot"};
     std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
     if(!std::includes(
            op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
     {
-        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
+        MIGRAPHX_THROW("QUANTIZE_INT8: only supports DOT and CONVOLUTION operations");
     }
-
-    std::unordered_map<instruction_ref, instruction_ref> ins_map;
-    for(auto ins : iterator_for(*mm))
-    {
-        if(not contains(ins_names, ins->name()))
-        {
-            continue;
-        }
-        auto inputs = ins->inputs();
-        std::vector<instruction_ref> new_args;
-        for(auto input : inputs)
-        {
-            instruction_ref new_ins{};
-            if(ins_map.count(input) > 0)
-            {
-                new_ins = ins_map[input];
-            }
-            else
-            {
-                new_ins = mm->insert_instruction(
-                    std::next(input), op::capture{num_quant_params++, func}, input);
-                ins_map[input] = new_ins;
-            }
-            new_args.push_back(new_ins);
-        }
-        instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
-    }
-
-    return num_quant_params;
-}
-
-std::shared_ptr<std::vector<std::pair<float, float>>>
-capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
-{
     std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
         std::make_shared<std::vector<std::pair<float, float>>>();
     std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
@@ -545,7 +60,6 @@
     auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
                                                                    std::vector<argument> args) {
         std::pair<float, float> param_pair{64.0f, 0.0f};
-
        // scale and shift are needed only for the int8 type, and we do not
        // consider shift, so set shift to 0
        std::vector<float> vec_val;
@@ -568,12 +82,56 @@
         int8_quant_params->at(ins_index) = param_pair;
     };
 
-    auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
-
-    int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
-    max_abs_vals->resize(num_params, 0.0f);
-
-    return int8_quant_params;
+    // pass to add capture argument op
+    std::size_t param_num = 0;
+    run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
+    int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
+    max_abs_vals->resize(param_num, 0.0f);
+
+    // use the calibration data to compute the quantization scale
+    auto capture_prog = prog;
+    capture_prog.compile(t);
+
+    // use all calibration data to run the program to calculate the
+    // quantization scale and shift
+    for(auto&& arg : calibration)
+    {
+        parameter_map m;
+        for(auto&& x : capture_prog.get_parameter_shapes())
+        {
+            if(arg.count(x.first) > 0)
+            {
+                assert(x.second == arg.at(x.first).get_shape());
+                m[x.first] = t.copy_to(arg.at(x.first));
+            }
+            else
+            {
+                m[x.first] = t.allocate(x.second);
+            }
+        }
+        capture_prog.eval(m);
+    }
+
+    // print the quantization parameters in only the main module
+    if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
+    {
+        for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
+        {
+            auto param = int8_quant_params->at(i);
+            std::cout << "ins_index = " << i << ", scale = " << param.first
+                      << ", shift = " << param.second << std::endl;
+        }
+        std::cout << std::endl;
+    }
+
+    run_passes(prog,
+               {quantize_int8_pass{ins_names, *int8_quant_params},
+                eliminate_common_subexpression{},
+                dead_code_elimination{},
+                simplify_reshapes{},
+                dead_code_elimination{},
+                simplify_qdq{},
+                dead_code_elimination{}});
 }
 } // namespace MIGRAPHX_INLINE_NS
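The body of calc_quant_params is collapsed in the hunk above (@@ -568,12 @@); what survives shows it reducing each captured argument and storing a (scale, shift) pair. A plausible reconstruction of the scale rule — an assumption, since those lines are elided — maps the running max |x| onto the int8 limit:

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Assumed shape of the elided lambda body: fold the argument's values into a
// running max-absolute-value, then pick the scale so that max |x| maps to 127.
std::pair<float, float> update_quant_param(const std::vector<float>& vec_val, float& max_abs)
{
    for(float v : vec_val)
        max_abs = std::max(max_abs, std::fabs(v));
    float scale = (max_abs > 0.0f) ? 127.0f / max_abs : 64.0f; // 64.0f is the default above
    return {scale, 0.0f};                                      // shift stays 0
}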
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void quantize_module(module& m, const std::vector<std::string>& ins_names)
{
for(auto ins : iterator_for(m))
{
// skip instructions that are not in the set to be quantized
if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
continue;
// skip return and convert instructions
if(contains({"@return", "convert"}, ins->name()))
continue;
if(ins->inputs().empty())
continue;
auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
auto input_type = input->get_shape().type();
if(input_type != shape::float_type and input_type != shape::double_type)
return input;
return m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), input);
});
// Replace inputs
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
}
}
void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
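A small usage sketch of the pass above (assuming quantize_fp16_pass aggregate-initializes from its ins_names member, as apply suggests): every matched instruction gets half-precision inputs plus a convert back, so the module's output type is unchanged.

#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/quantize_fp16.hpp>

int main()
{
    migraphx::module m;
    auto x = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto y = m.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
    m.add_instruction(migraphx::make_op("add"), x, y);
    // "all" would quantize every eligible instruction; here only "add"
    migraphx::run_passes(m, {migraphx::quantize_fp16_pass{{"add"}}});
    return 0;
}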
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantable_types = {
shape::float_type, shape::double_type, shape::half_type};
return quantable_types;
}
void quantize_int8_pass::apply(module& m) const // NOLINT
{
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
{
if(ins->name() != "capture")
continue;
auto op_val = ins->get_operator().to_value();
assert(op_val.contains("ins_index"));
auto param_index = op_val.at("ins_index").to<std::size_t>();
auto param = quant_params[param_index];
auto input = ins->inputs().front();
auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
{
auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens();
scale =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
zero_point = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
auto q_in =
m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
auto dq_in =
m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
m.replace_instruction(ins, dq_in);
}
}
}
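The inserted quantizelinear/dequantizelinear pair round-trips each element through int8; following ONNX-style semantics, it computes q = saturate(round(x / scale) + zero_point) and back. A scalar check of that round trip (note the pass stores 1.0f / param.first as the scale literal):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
    float x           = 0.37f;
    float scale       = 1.0f / 64.0f; // literal built above as 1.0f / param.first
    int8_t zero_point = 0;
    // quantizelinear: q = saturate(round(x / scale) + zero_point)
    float r = std::nearbyint(x / scale) + zero_point;
    auto q  = static_cast<int8_t>(std::min(std::max(r, -128.0f), 127.0f));
    // dequantizelinear: y = (q - zero_point) * scale
    float y = (q - zero_point) * scale;
    // round-trip error is bounded by half a quantization step
    assert(std::fabs(y - x) <= 0.5f * scale);
    return 0;
}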
void capture_arguments_pass::apply(module& m) const // NOLINT
{
assert(param_index != nullptr);
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
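For context, op::capture (used by capture_arguments_pass above) behaves as an identity on the graph value while invoking the stored callback at evaluation time, which is how calibration observes real inputs. A rough standalone sketch of that contract — not the MIGraphX implementation:

#include <cstddef>
#include <functional>
#include <vector>

// Illustrative stand-in for op::capture's compute: pass the argument
// through unchanged, but let the callback observe it first.
template <class Argument>
Argument capture_compute(std::size_t ins_index,
                         const std::function<void(std::size_t, std::vector<Argument>)>& f,
                         const Argument& arg)
{
    if(f)
        f(ins_index, {arg});
    return arg; // identity: the graph value is unmodified
}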
@@ -16,10 +16,8 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
         auto bstride = s.strides()[n + 1];
         auto blen    = s.lens()[n + 1];
-        if(astride == bstride * blen)
-        {
+        if(astride == bstride * blen or alen == 1)
             new_lens.push_back(alen * blen);
-        }
     }
     if(new_lens.size() != shapes.size())
         return false;
@@ -37,12 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
     return true;
 }
 
+void reduce_dim1(std::vector<shape>& shapes)
+{
+    if(std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
+           return s.lens().size() < 2 or s.lens().back() != 1;
+       }))
+        return;
+    for(auto& s : shapes)
+    {
+        auto lens    = s.lens();
+        auto strides = s.strides();
+        lens.pop_back();
+        strides.pop_back();
+        s = shape{s.type(), lens, strides};
+    }
+}
+
 std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
 {
-    while(reduce_dim(shapes, n) and n < shapes.size())
-    {
-    }
+    while(reduce_dim(shapes, n) and n < shapes.size()) {}
     return n + 1;
 }
@@ -50,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
 
 void reduce_dim_all(std::vector<shape>& shapes)
 {
     std::size_t n = 0;
     while(n < shapes.front().lens().size() - 1)
         n = reduce_dim_all(shapes, n);
+    reduce_dim1(shapes);
 }
 
 std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
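reduce_dim merges dimensions n and n+1 when the outer stride equals blen * bstride, i.e. when the pair is laid out contiguously; the new `or alen == 1` case also lets a length-1 outer dimension fold away regardless of its stride. A worked check with ordinary row-major strides:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // lens {2, 3, 4} with packed row-major strides {12, 4, 1}
    std::vector<std::size_t> lens{2, 3, 4};
    std::vector<std::size_t> strides{12, 4, 1};
    std::size_t n = 0;
    // mergeable: strides[n] == strides[n + 1] * lens[n + 1]  (12 == 4 * 3)
    assert(strides[n] == strides[n + 1] * lens[n + 1]);
    // the merged view lens {6, 4}, strides {4, 1} addresses the same elements:
    // offset 12*i + 4*j + k equals the merged offset 4*(3*i + j) + k
    return 0;
}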
-#include <migraphx/register_target.hpp>
 #include <unordered_map>
+#include <migraphx/register_target.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -11,7 +11,17 @@ std::unordered_map<std::string, target>& target_map()
 }
 
 void register_target(const target& t) { target_map()[t.name()] = t; }
 
-target make_target(const std::string& name) { return target_map().at(name); }
+target make_target(const std::string& name)
+{
+    const auto it = target_map().find(name);
+    if(it == target_map().end())
+    {
+        MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported");
+    }
+    return it->second;
+}
 
 std::vector<std::string> get_targets()
 {
     std::vector<std::string> result;
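With this change, asking for a target that was never registered (for example a build without the GPU backend) raises a descriptive MIGraphX error instead of the bare std::out_of_range that map::at would throw. A hypothetical call site:

#include <iostream>
#include <migraphx/register_target.hpp>

int main()
{
    try
    {
        // succeeds only if the gpu target library was compiled in and loaded
        auto t = migraphx::make_target("gpu");
        std::cout << t.name() << " target available\n";
    }
    catch(const std::exception& e)
    {
        // "Requested target 'gpu' is not enabled or not supported"
        std::cout << e.what() << "\n";
    }
    return 0;
}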
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
auto matcher() const
{
return match::name("add")(match::any_of(
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
match::args(match::used_once().bind("a"),
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = r.instructions["a"];
auto dot = any_cast<op::dot>(dot_ins->get_operator());
dot.beta = 1;
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
}
};
} // namespace
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
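The rewrite above turns add(dot(A, B), C) into a single gemm-style dot computing alpha * A·B + beta * C with beta = 1, which is why it only fires when the dot has exactly two inputs. A scalar sanity check of that identity:

#include <cassert>

int main()
{
    // one output element of a GEMM: d = alpha * (a . b) + beta * c
    float a0 = 1.0f, a1 = 2.0f, b0 = 3.0f, b1 = 4.0f, c = 5.0f;
    float dot_then_add = (a0 * b0 + a1 * b1) + c;               // add(dot(a, b), c)
    float fused        = 1.0f * (a0 * b0 + a1 * b1) + 1.0f * c; // dot with beta = 1
    assert(dot_then_add == fused);
    return 0;
}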