Commit c1ec929c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents abe2a889 03225b57
...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val) ...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val)
return j.dump(); return j.dump();
} }
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
migraphx::value from_json_string(const char* str, std::size_t size) migraphx::value from_json_string(const char* str, std::size_t size)
{ {
json j = json::parse(str, str + size); json j = json::parse(str, str + size);
......
...@@ -32,9 +32,20 @@ struct onnx_parser ...@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args, instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins, instruction_ref curr_ins,
uint64_t axis) const; uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name, instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const; instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op, instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const; const std::vector<instruction_ref>& args) const;
......
...@@ -98,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -98,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const instruction_ref arg1) const
{ {
return add_common_op(*mod, make_op(op_name), {arg0, arg1}); return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
} }
instruction_ref instruction_ref
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v, ...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start) std::vector<int64_t>& s_start)
{ {
// maxpooling or count_include_pad is 1, no change is required. // maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1) if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{ {
return; return;
} }
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
std::vector<op_desc> operators() const { return {{"Celu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
if(float_equal(alpha, 0.0f))
{
MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
if(input_type != migraphx::shape::float_type)
{
MIGRAPHX_THROW("CELU: input tensor not float type");
}
auto zero_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]);
auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit);
auto expo = info.add_instruction(migraphx::make_op("exp"), divi);
auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit);
auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub);
auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul);
return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip> ...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg; instruction_ref min_arg;
instruction_ref max_arg; instruction_ref max_arg;
bool min_used = false; bool min_used = false;
...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip> ...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip>
max_used = true; max_used = true;
} }
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
if(min_used and max_used) if(min_used and max_used)
{ {
return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg); return info.add_common_op("clip", args[0], min_arg, max_arg);
} }
else if(max_used) else if(max_used)
{ {
return info.add_instruction(make_op("min"), args[0], max_arg); return info.add_broadcastable_binary_op("min", args[0], max_arg);
} }
else if(min_used) else if(min_used)
{ {
return info.add_instruction(make_op("max"), args[0], min_arg); return info.add_broadcastable_binary_op("max", args[0], min_arg);
} }
else else
{ {
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_eyelike : op_parser<parse_eyelike>
{
std::vector<op_desc> operators() const { return {{"EyeLike"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
auto input_lens = input_shape.lens();
if(input_lens.size() != 2)
{
MIGRAPHX_THROW("EYELIKE: tensor input not of rank 2");
}
std::ptrdiff_t num_rows = input_lens.front();
std::ptrdiff_t num_cols = input_lens.back();
shape::type_t output_type = args[0]->get_shape().type();
if(contains(info.attributes, "dtype"))
{
output_type = get_type(info.attributes.at("dtype").i());
}
std::ptrdiff_t k = 0;
if(contains(info.attributes, "k"))
{
k = info.attributes.at("k").i();
}
if(k >= 0)
{
if(k >= num_cols)
{
std::ostringstream oss;
oss << "EYELIKE: positive k out of bounds, k = " << k << " num_cols = " << num_cols;
MIGRAPHX_THROW(oss.str());
}
}
else
{
if(std::abs(k) >= num_rows)
{
std::ostringstream oss;
oss << "EYELIKE: negative k out of bounds, k = " << k << " num_rows = " << num_cols;
MIGRAPHX_THROW(oss.str());
}
}
std::vector<char> eyelike_mat(num_rows * num_cols, 0);
for(std::ptrdiff_t i = 0; i < num_rows; ++i)
{
auto idx = i + k;
if(idx < num_cols and idx >= 0)
eyelike_mat[(num_cols + 1) * i + k] = char{1};
}
return info.add_literal(
migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for LpNormalization ONNX operator.
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
*/
struct parse_lpnormalization : op_parser<parse_lpnormalization>
{
std::vector<op_desc> operators() const { return {{"LpNormalization"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int p = 2;
if(contains(info.attributes, "p"))
{
p = info.attributes.at("p").i();
}
if(p != 1 and p != 2)
{
MIGRAPHX_THROW("LPNORMALIZATION: only L1 and L2 norm supported");
}
auto input = args.front();
auto input_shape = input->get_shape();
const auto& input_lens = input_shape.lens();
auto input_type = input_shape.type();
std::ptrdiff_t num_axes = input_lens.size();
std::ptrdiff_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = info.attributes.at("axis").i();
if(axis < -num_axes or axis >= num_axes)
{
// handled in normalize_attributes but throwing here might be clearer
MIGRAPHX_THROW("LPNORMALIZATION: selected axis out of bounds");
}
}
migraphx::instruction_ref p_val;
if(p == 1)
{
p_val = info.add_instruction(migraphx::make_op("abs"), input);
}
else
{
p_val = info.add_instruction(migraphx::make_op("mul"), input, input);
}
// need to check for zeros from lp norm to prevent division by zero
// change them to 1 for the element-wise division
auto norms =
info.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val);
if(p == 2)
{
norms = info.add_instruction(migraphx::make_op("sqrt"), norms);
}
// broadcast back to initial shape, negative axis option doesn't work with unidirectional
norms = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms);
auto zero_mb = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_mb = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto is_zero = info.add_instruction(migraphx::make_op("equal"), norms, zero_mb);
auto norms_zeros_to_one =
info.add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms);
return info.add_instruction(migraphx::make_op("div"), input, norms_zeros_to_one);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp> #include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -27,10 +28,16 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -27,10 +28,16 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::string mode = opd.op_name; std::string mode = opd.op_name;
operation op = make_op("pooling", {{"mode", mode}}); if(mode != "max" && mode != "average")
value values = op.to_value(); {
auto l0 = args[0]; MIGRAPHX_THROW("onnx pooling mode must be \"max\" or \"average\"");
auto in_lens = l0->get_shape().lens(); }
operation op = make_op(
"pooling",
{{"mode", mode == "average" ? op::pooling_mode::average : op::pooling_mode::max}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2); assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2; auto kdims = in_lens.size() - 2;
...@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f); float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
values["padding"].clear(); values["padding"].clear();
......
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign>
"\": invalid value!"); "\": invalid value!");
} }
std::string mode = "avg"; migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
if(contains(info.attributes, "mode")) if(contains(info.attributes, "mode"))
{ {
mode = info.attributes.at("mode").s(); // read mode; default is "avg"
if(info.attributes.at("mode").s() == "max")
{
rmode = migraphx::op::pooling_mode::max;
}
} }
int64_t output_height = 1; int64_t output_height = 1;
...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign>
{ {
spatial_scale = info.attributes.at("spatial_scale").f(); spatial_scale = info.attributes.at("spatial_scale").f();
} }
return info.add_instruction(make_op("roialign", return info.add_instruction(make_op("roialign",
{{"coordinate_transformation_mode", coord_trans_mode}, {{"coordinate_transformation_mode", coord_trans_mode},
{"mode", mode}, {"mode", rmode},
{"output_height", output_height}, {"output_height", output_height},
{"output_width", output_width}, {"output_width", output_width},
{"sampling_ratio", sampling_ratio}, {"sampling_ratio", sampling_ratio},
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_size : op_parser<parse_size>
{
std::vector<op_desc> operators() const { return {{"Size"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
return info.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type},
{args[0]->get_shape().elements()}});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static const std::vector<std::string> pooling_mode_str = {"average", "max"};
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
return os;
}
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
static const std::vector<std::string> rnn_direction_str = {
"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -353,17 +353,20 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -353,17 +353,20 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 0) if(trace_level > 0)
{ {
std::unordered_map<instruction_ref, std::string> ins_names; std::unordered_map<instruction_ref, std::string> ins_out;
// get instruction names // get instruction names
this->print(ins_names, [](auto, auto) {}); this->print([&](auto x, auto ins_names) {
std::stringstream ss;
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
return generic_eval(*this, return generic_eval(*this,
ctx, ctx,
std::move(params), std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) { with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: "; std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
this->debug_print(ins, ins_names);
std::cout << std::endl;
timer t{}; timer t{};
auto result = check_context(f); auto result = check_context(f);
double t1 = t.record<milliseconds>(); double t1 = t.record<milliseconds>();
...@@ -742,6 +745,14 @@ void program::print( ...@@ -742,6 +745,14 @@ void program::print(
} }
} }
void program::print(
const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>& print_func) const
{
std::unordered_map<instruction_ref, std::string> names;
this->print(names, print_func);
}
void program::print_graph(std::ostream& os, bool brief) const void program::print_graph(std::ostream& os, bool brief) const
{ {
const auto* mm = this->get_main_module(); const auto* mm = this->get_main_module();
......
...@@ -211,12 +211,21 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -211,12 +211,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
.def(py::init<>()) .def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", std::string{"float"}));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
else
return migraphx::shape(t, lens);
}))
.def("type", &migraphx::shape::type) .def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens) .def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides) .def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements) .def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes) .def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size) .def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed) .def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed) .def("transposed", &migraphx::shape::transposed)
......
...@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const ...@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const
instruction_ref pooling{}; instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
pooling = pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
......
...@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog, ...@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
return hs_padded; return hs_padded;
} }
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -460,10 +460,10 @@ struct cpu_apply ...@@ -460,10 +460,10 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>()) not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value())); return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>(); op::pooling_mode mode = v["mode"].to<op::pooling_mode>();
if(mode == "max") if(mode == op::pooling_mode::max)
return replace(ins, make_op("cpu::pooling_max", v)); return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average") else if(mode == op::pooling_mode::average)
return replace(ins, make_op("cpu::pooling_average", v)); return replace(ins, make_op("cpu::pooling_average", v));
return ins; return ins;
} }
......
...@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg; auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
...@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -240,8 +240,9 @@ std::string enum_params(std::size_t count, std::string param) ...@@ -240,8 +240,9 @@ std::string enum_params(std::size_t count, std::string param)
std::size_t compute_global(std::size_t n, std::size_t local) std::size_t compute_global(std::size_t n, std::size_t local)
{ {
std::size_t groups = (n + local - 1) / local; std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local; // max possible number of blocks is set to 1B (1,073,741,824)
std::size_t nglobal = std::min<std::size_t>(1073741824, groups) * local;
return nglobal; return nglobal;
} }
......
...@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const ...@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
// pooling_mode // pooling_mode
assert(val.contains("mode")); assert(val.contains("mode"));
auto mode = val.at("mode").to<std::string>(); auto mode = val.at("mode").to<migraphx::op::pooling_mode>();
bool is_avg_pooling = (mode == "avg"); bool is_avg_pooling = (mode == migraphx::op::pooling_mode::average);
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling)); options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode // coord_trans_mode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment