Commit 870a396b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 228b665c d309e02f
......@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p,
return generic_eval(mm, ctx, params, {}, make_trace);
}
std::vector<argument> program::eval(parameter_map params) const
std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
......@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params) const
#endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret;
if(exec_env.async)
{
ctx.wait_for(exec_env.queue);
}
if(trace_level > 0)
{
......@@ -434,49 +440,56 @@ std::vector<argument> program::eval(parameter_map params) const
ins_out[x] = ss.str();
});
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
{
target tgt = make_target(this->impl->target_name);
auto buffer = tgt.copy_from(result);
if(trace_level == 2)
{
std::cout << "Output has "
<< to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
}
return result;
}));
ret = generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
{
target tgt = make_target(this->impl->target_name);
auto buffer = tgt.copy_from(result);
if(trace_level == 2)
{
std::cout << "Output has "
<< to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
}
return result;
}));
}
else
{
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
ret = generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
}
if(exec_env.async)
{
ctx.finish_on(exec_env.queue);
}
return ret;
}
const int program_file_version = 5;
......@@ -841,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm->print_graph(os, brief);
}
void program::print_py(std::ostream& os) const
{
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
os << "p = migraphx.program()\n";
for(auto& mod : vec_modules)
{
std::string var_name = "m" + mod->name();
os << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module()";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_py(os, var_name, names);
os << std::endl;
}
}
void program::print_cpp(std::ostream& os) const
{
auto vec_modules = this->get_modules();
......
......@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
}
return p.eval(pm);
})
.def("run_async",
[](migraphx::program& p,
py::dict params,
std::uintptr_t stream,
std::string stream_name) {
migraphx::parameter_map pm;
for(auto x : params)
{
std::string key = x.first.cast<std::string>();
py::buffer b = x.second.cast<py::buffer>();
py::buffer_info info = b.request();
pm[key] = migraphx::argument(to_shape(info), info.ptr);
}
migraphx::execution_environment exec_env{
migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
return p.eval(pm, exec_env);
})
.def("sort", &migraphx::program::sort)
.def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
.def("__eq__", std::equal_to<migraphx::program>{})
......
......@@ -73,7 +73,7 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
name_shapes.insert(ps.begin(), ps.end());
}
for(auto& pn : name_shapes)
for(const auto& pn : name_shapes)
{
const auto& s = pn.second;
instruction_ref output{};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "batch_norm_inference")
continue;
// Get scale, bias, mean, variance from inputs
auto gamma = ins->inputs()[1]->eval();
auto bias = ins->inputs()[2]->eval();
auto mean = ins->inputs()[3]->eval();
auto variance = ins->inputs()[4]->eval();
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue;
std::vector<std::size_t> lens = ins->inputs()[1]->get_shape().lens();
shape s{ins->get_shape().type(), lens};
// Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
argument a{s};
argument b{s};
visit_all(gamma, bias, mean, variance, a, b)(
[&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
dfor(a.get_shape().elements())(
[&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
dfor(b.get_shape().elements())([&](std::size_t c) {
b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = m.add_literal({a.get_shape(), a.data()});
auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins);
auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = m.add_literal({b.get_shape(), b.data()});
auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins);
auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast);
m.replace_instruction(ins, add);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -46,9 +46,6 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -95,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process sequence length
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
if((args.size() >= 5) and not args[4]->is_undefined())
{
seq_lens = args[4];
}
......@@ -120,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 and not args[3]->is_undefined())
{
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
......@@ -132,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
if(args.size() == 6 and not args[5]->is_undefined())
{
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
......@@ -198,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias and initial hidden state
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 and not args[3]->is_undefined())
{
bias = args[3];
}
// process intial hidden state
instruction_ref ih;
if(args.size() == 6 && args[5]->name() != "undefined")
if(args.size() == 6 and not args[5]->is_undefined())
{
ih = args[5];
}
......@@ -401,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// process sequence length
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
if((args.size() >= 5) and not args[4]->is_undefined())
{
seq_lens = args[4];
}
......@@ -426,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 and not args[3]->is_undefined())
{
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
......@@ -437,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// intial hidden state
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
if(args.size() == 6 and not args[5]->is_undefined())
{
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
......@@ -504,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 and not args[3]->is_undefined())
{
bias = args[3];
}
// intial hidden state
instruction_ref ih{};
if(args.size() == 6 && args[5]->name() != "undefined")
if(args.size() == 6 and not args[5]->is_undefined())
{
ih = args[5];
}
......@@ -787,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process sequence length
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
if((args.size() >= 5) and not args[4]->is_undefined())
{
seq_lens = args[4];
}
......@@ -816,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process bias
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 and not args[3]->is_undefined())
{
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
......@@ -827,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process intial hidden state, it is the 6th argument
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
if(args.size() >= 6 and not args[5]->is_undefined())
{
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
......@@ -843,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process initial cell value
instruction_ref ic_forward{};
instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined")
if(args.size() >= 7 and not args[6]->is_undefined())
{
ic_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
......@@ -859,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole
instruction_ref pph_forward = m.end();
instruction_ref pph_reverse = m.end();
if(args.size() == 8 && args[7]->name() != "undefined")
if(args.size() == 8 and not args[7]->is_undefined())
{
pph_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
......@@ -943,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// bias
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 and not args[3]->is_undefined())
{
bias = args[3];
}
// initial hidden state
instruction_ref ih{};
if(args.size() >= 6 && args[5]->name() != "undefined")
if(args.size() >= 6 and not args[5]->is_undefined())
{
ih = args[5];
}
......@@ -961,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// initial cell value
instruction_ref ic{};
if(args.size() >= 7 && args[6]->name() != "undefined")
if(args.size() >= 7 and not args[6]->is_undefined())
{
ic = args[6];
}
......@@ -972,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole
instruction_ref pph = m.end();
if(args.size() == 8 && args[7]->name() != "undefined")
if(args.size() == 8 and not args[7]->is_undefined())
{
pph = args[7];
}
......
......@@ -71,6 +71,19 @@ struct shape_impl
{
}
shape_impl(shape::type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: m_type(t)
{
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]});
}
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type;
......@@ -224,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
{
}
shape::shape(type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts)))
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
......@@ -244,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::ndim() const
{
if(this->dynamic())
{
return dyn_dims().size();
}
return lens().size();
}
std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
......@@ -437,6 +467,16 @@ shape shape::with_type(type_t t) const
return {c};
}
shape shape::to_dynamic() const
{
if(this->dynamic())
{
return *this;
}
std::vector<std::size_t> zeroes(this->ndim(), 0);
return {type(), lens(), lens(), zeroes};
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
......@@ -464,15 +504,36 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
this->min += x;
this->max += x;
if(this->opt != 0)
{
this->opt += x;
};
return *this;
}
shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
assert(this->min >= x);
assert(this->max >= x);
this->min -= x;
this->max -= x;
if(this->opt != 0)
{
assert(this->opt >= x);
this->opt -= x;
}
return *this;
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
// don't check opt if both are fixed
return (x.min == y.min and x.max == y.max and
((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......@@ -485,6 +546,31 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
return os;
}
bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
{
return x.min == y and x.max == y;
}
bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
{
auto dd = x;
return dd += y;
}
shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
{
return y + x;
}
shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
{
auto dd = x;
return dd -= y;
}
bool operator==(const shape& x, const shape& y)
{
if(x.dynamic() and y.dynamic())
......
......@@ -57,12 +57,14 @@ auto conv_const_weights()
auto reduction() { return match::name_contains("reduce"); }
// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast").bind("a")));
return match::name("mul")(
match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast", "multibroadcast").bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
......@@ -72,14 +74,35 @@ struct find_mul_conv
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if(broadcast_op.axis != 1)
const auto& a_input_lens = a_ins->inputs().front()->get_shape().lens();
std::size_t num_not_one_dims = std::count_if(
a_input_lens.cbegin(), a_input_lens.cend(), [](auto dim) { return dim != 1; });
if(num_not_one_dims > 1)
return;
// check broadcasted along channels
const auto& a_lens = a_ins->get_shape().lens();
const auto& a_strides = a_ins->get_shape().strides();
auto is_broadcasted_axis = [](auto len, auto stride) { return len == 1 or stride == 0; };
if(a_strides.at(1) != 1)
return;
if(not is_broadcasted_axis(a_lens.front(), a_strides.front()))
return;
if(not std::equal(a_lens.begin() + 2,
a_lens.end(),
a_strides.begin() + 2,
a_strides.end(),
is_broadcasted_axis))
return;
auto sq = m.insert_instruction(ins, make_op("squeeze"), a_ins->inputs().front());
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), sq);
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
......@@ -412,6 +435,24 @@ struct find_concat_op
}
};
void move_instructions_back(module& m, instruction_ref pos, std::vector<instruction_ref> inss)
{
auto start = range(m.begin(), pos);
for(auto ins : iterator_for(start))
{
auto it = std::find(inss.begin(), inss.end(), ins);
if(it != inss.end())
inss.erase(it);
}
for(auto ins : inss)
{
if(not m.has_instruction(ins))
continue;
move_instructions_back(m, pos, ins->inputs());
m.move_instruction(ins, pos);
}
}
std::vector<instruction_ref> get_splits(instruction_ref ins)
{
std::vector<instruction_ref> result;
......@@ -587,8 +628,7 @@ struct find_splits
}))
return;
for(auto data : data_args)
m.move_instructions(data, ins);
move_instructions_back(m, ins, data_args);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
......@@ -787,7 +827,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return not(dots < 2 and convs < 2);
return (dots >= 2 or convs >= 2);
}
struct find_conv_dot_horiz_fusion
......@@ -841,8 +881,7 @@ struct find_conv_dot_horiz_fusion
concat_axis = axis;
}
for(auto arg : args)
m.move_instructions(arg, input);
move_instructions_back(m, input, args);
// TODO: Check if axes match
auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
......@@ -894,6 +933,73 @@ struct find_div_const
}
};
struct find_unit_ops
{
auto matcher() const
{
auto mul_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(1.0f), match::any().bind("x")));
auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f, 1e-12), match::any().bind("x")));
auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f)));
return match::any_of(mul_1, div_1, add_0, sub_0);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_in = r.instructions["x"];
m.replace_instruction(ins, c_in);
}
};
struct find_neg_unit_ops
{
auto matcher() const
{
auto mul_neg_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(-1.0f), match::any().bind("x")));
auto div_neg_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f)));
auto sub_0 =
match::name("sub")(match::args(match::has_value(0.0f), match::any().bind("x")));
return match::any_of(mul_neg_1, div_neg_1, sub_0);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_in = r.instructions["x"];
auto neg = m.add_instruction(make_op("neg"), c_in);
m.replace_instruction(ins, neg);
}
};
struct find_zero_ops
{
auto matcher() const
{
auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
auto div_zero =
match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
return match::any_of(mul_zero, div_zero);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto zero_ins = r.instructions["x"];
m.replace_instruction(ins, zero_ins);
}
};
struct find_sub_const
{
auto matcher() const
......@@ -959,11 +1065,23 @@ struct find_split_reshape
return;
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
if(i->outputs().size() == 1)
{
auto cont = i->outputs().front();
return cont->outputs().size() != 1;
}
return false;
}))
{
return;
}
std::vector<instruction_ref> vec_rsp(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
assert(i->outputs().size() == 1);
auto cont = i->outputs().front();
assert(cont->outputs().size() == 1);
return cont->outputs().front();
});
......@@ -1110,6 +1228,9 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
find_zero_ops{},
find_dot_add{},
find_div_const{},
find_sub_const{},
......
......@@ -763,16 +763,23 @@ struct find_transpose_slice
// Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze
// Make unsqueeze
std::vector<int64_t> steps(sdistance.size());
std::transform(
slice.axes.begin(),
slice.axes.end(),
sdistance.begin(),
steps.begin(),
[&](const auto ax, const auto sdis) { return ins->get_shape().lens().at(ax) / sdis; });
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", steps}}), ins->inputs());
// Make transpose
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i > preaxis)
if(i >= preaxis)
return i + 1;
return i;
});
perm.insert(perm.begin(), preaxis + 1);
perm.insert(perm.begin(), preaxis);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and squeeze
......
......@@ -35,6 +35,7 @@ add_library(migraphx_cpu
dnnl.cpp
eltwise.cpp
erf.cpp
fmod.cpp
fuse_ops.cpp
gather.cpp
gemm.cpp
......@@ -42,6 +43,7 @@ add_library(migraphx_cpu
logsoftmax.cpp
lowering.cpp
lrn.cpp
mod.cpp
preallocate.cpp
pooling.cpp
reduction.cpp
......
......@@ -51,19 +51,18 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto r = s0;
if(s0 == s1 and s0.packed())
{
r = s0;
}
else if(s0.packed() != s1.packed())
{
r = s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
r = {s0.type(), s0.lens()};
if(s0.packed() != s1.packed())
{
r = s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
r = {s0.type(), s0.lens()};
}
}
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
......
......@@ -21,22 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_COS_HPP
#define MIGRAPHX_GUARD_RTGLIB_COS_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/cos.hpp>
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/fmod.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace cpu {
struct hip_cos : unary_device<hip_cos, device::cos>
{
};
template struct cpu_binary<op::fmod>;
} // namespace gpu
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -26,7 +26,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
......@@ -43,6 +42,8 @@
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/mod.hpp>
#include <migraphx/op/fmod.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
......@@ -214,55 +215,6 @@ struct cpu_pad
};
MIGRAPHX_REGISTER_OP(cpu_pad)
struct leaky_relu_op
{
op::leaky_relu op;
std::string name() const { return "cpu::leaky_relu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : x * a; };
}
};
template <typename Op>
struct cpu_unary2 : auto_register_op<cpu_unary2<Op>>
{
cpu_unary2() = default;
template <class T>
cpu_unary2(T pop) : op(Op{std::move(pop)})
{
}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op.op, f);
}
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
assert(input.get_shape().standard());
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
return result;
}
};
template struct cpu_unary2<leaky_relu_op>;
struct cpu_rnn_var_sl_last_output
{
op::rnn_var_sl_last_output op;
......
......@@ -21,22 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_EXP_HPP
#define MIGRAPHX_GUARD_RTGLIB_EXP_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/exp.hpp>
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/mod.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace cpu {
struct hip_exp : unary_device<hip_exp, device::exp>
{
};
template struct cpu_binary<op::mod>;
} // namespace gpu
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -38,12 +38,10 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
......@@ -79,8 +77,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
rewrite_batchnorm{},
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
eliminate_common_subexpression{},
......
......@@ -39,81 +39,9 @@ file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES})
add_library(migraphx_device
device/acos.cpp
device/acosh.cpp
device/add.cpp
device/add_clip.cpp
device/add_relu.cpp
device/add_sigmoid.cpp
device/add_tanh.cpp
device/argmax.cpp
device/argmin.cpp
device/asin.cpp
device/asinh.cpp
device/atan.cpp
device/atanh.cpp
device/ceil.cpp
device/clip.cpp
device/concat.cpp
device/contiguous.cpp
device/convert.cpp
device/cos.cpp
device/cosh.cpp
device/div.cpp
device/equal.cpp
device/erf.cpp
device/exp.cpp
device/fill.cpp
device/floor.cpp
device/gather.cpp
device/gelu.cpp
device/greater.cpp
device/int8_gemm_pack.cpp
device/layernorm.cpp
device/less.cpp
device/log.cpp
device/logical_and.cpp
device/logical_or.cpp
device/logical_xor.cpp
device/logsoftmax.cpp
device/max.cpp
device/min.cpp
device/mul.cpp
device/mul_add.cpp
device/mul_add_relu.cpp
device/multinomial.cpp
device/nonzero.cpp
device/pad.cpp
device/pow.cpp
device/prelu.cpp
device/prefix_scan_sum.cpp
device/recip.cpp
device/reduce_max.cpp
device/reduce_mean.cpp
device/reduce_min.cpp
device/reduce_sum.cpp
device/reduce_prod.cpp
device/relu.cpp
device/reverse.cpp
device/rnn_variable_seq_lens.cpp
device/round.cpp
device/rsqrt.cpp
device/scatter.cpp
device/sigmoid.cpp
device/sign.cpp
device/sin.cpp
device/sinh.cpp
device/softmax.cpp
device/sqdiff.cpp
device/sqrt.cpp
device/sub.cpp
device/tan.cpp
device/tanh.cpp
device/topk.cpp
device/unary_not.cpp
device/where.cpp
)
file(GLOB DEVICE_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
......@@ -150,20 +78,14 @@ add_library(migraphx_gpu
allocation_model.cpp
argmax.cpp
argmin.cpp
batch_norm_inference.cpp
clip.cpp
code_object_op.cpp
compile_ops.cpp
compile_gen.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_miopen.cpp
compiler.cpp
concat.cpp
convert.cpp
convolution.cpp
deconvolution.cpp
device_name.cpp
elu.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
......@@ -176,7 +98,6 @@ add_library(migraphx_gpu
logsoftmax.cpp
loop.cpp
lrn.cpp
leaky_relu.cpp
mlir.cpp
multinomial.cpp
nonzero.cpp
......@@ -186,13 +107,11 @@ add_library(migraphx_gpu
pad.cpp
perfdb.cpp
pooling.cpp
quant_convolution.cpp
reverse.cpp
rnn_variable_seq_lens.cpp
rocblas.cpp
scatter.cpp
schedule_model.cpp
softmax.cpp
sync_device.cpp
target.cpp
topk.cpp
......@@ -207,81 +126,25 @@ function(register_migraphx_gpu_ops PREFIX)
endforeach()
endfunction()
register_migraphx_gpu_ops(hip_
acosh
acos
add
argmax
argmin
asinh
asin
atanh
atan
ceil
clip
concat
convert
cosh
cos
div
equal
erf
exp
floor
gather
greater
less
log
logsoftmax
logical_and
logical_or
logical_xor
loop
max
min
mul
multinomial
nonzero
pad
pow
prelu
prefix_scan_sum
recip
reduce_max
reduce_mean
reduce_min
reduce_prod
reduce_sum
relu
reverse
round
rsqrt
scatter
sigmoid
sign
sinh
sin
softmax
sqdiff
sqrt
sub
tanh
tan
topk
unary_not
where
)
register_migraphx_gpu_ops(miopen_
abs
batch_norm_inference
contiguous
convolution
deconvolution
elu
int8_conv_pack
leaky_relu
lrn
pooling
quant_convolution
)
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
......@@ -295,6 +158,9 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)
......@@ -365,9 +231,21 @@ endif()
include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# TODO: Set default to HAS_FIND_2_API
set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
message(STATUS "MIGraphx is using legacy Find API in MIOpen")
endif()
if(HAS_FIND_MODE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_MODE_API)
message(STATUS "MIOpen has find mode api")
message(STATUS "MIGraphx is using Find Mode API of MIOpen")
else()
message(STATUS "MIOpen does not have find mode api")
endif()
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/clip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_clip::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument hip_clip::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::clip(ctx.get_stream().get(), args.back(), args.front(), args.at(1), args.at(2));
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
......@@ -48,12 +49,13 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2};
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
......@@ -81,6 +83,33 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
}
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
inputs.end(),
by(std::less<>{}, [](const auto& s) { return s.elements(); }))
->elements();
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global;
bool broadcasted =
std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); });
std::vector<std::size_t> sizes;
if(broadcasted and over > 8)
sizes.push_back(8);
if(over > 4)
sizes.push_back(4);
sizes.push_back(2);
return elements(axis, inputs, sizes);
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
return elements(axis, inputs, vector_sizes(inputs));
}
std::string vectorize::str() const
{
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
......@@ -102,7 +131,7 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
std::size_t bytes = 0;
for(auto i : preloaded)
{
auto input = inputs[i];
const auto& input = inputs[i];
bytes += input.bytes();
if(bytes > max_lds_bytes)
break;
......
......@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat");
options.push_back("--cuda-gpu-arch=" + arch);
options.push_back("--offload-arch=" + arch);
prog.compile(options);
return {prog.get_code_obj()};
}
......@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
}
else if(is_hip_clang_compiler())
{
params += " --cuda-gpu-arch=" + arch;
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
......
......@@ -138,16 +138,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t groups = (n + local - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return nglobal;
return std::min(nglobal, n);
};
}
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
size_t block_size = 128;
while(block_size <= max_block_size and block_size <= n)
block_size *= 2;
return block_size / 2;
const std::size_t min_block_size = 64;
auto block_size = (((n - 1) / min_block_size + 1)) * min_block_size;
return std::min(std::max(min_block_size, block_size), max_block_size);
}
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
......
......@@ -21,53 +21,81 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/rocblas.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_quant_convolution
struct miopen_op
{
op::quant_convolution op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
uint64_t solution_id = 0;
operation op = op::identity{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return pack_join(migraphx::reflect(self.op, f),
pack(f(self.int8_x4_format, "int8_x4_format")));
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::miopen_op"; }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.push_back(inputs.back());
return op.compute_shape(inputs);
}
std::string name() const { return "gpu::quant_convolution"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape find(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
private:
shape pack_int8_shape(const shape& s) const;
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get<std::size_t>("workspace", 0);
}
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0;
try
{
// for the regular convolution and deconvolution, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc);
m.replace_instruction(ins, op, inputs);
}
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment