Unverified Commit ebbaf8fc authored by turneram's avatar turneram Committed by GitHub
Browse files

Refactor where op (#918)

Implement Where as a dedicated operator for the CPU and GPU targets, replacing the ONNX parser's previous concat/reshape/gather lowering. This is for better performance.
parent 521b57a2
...@@ -172,6 +172,7 @@ register_migraphx_ops( ...@@ -172,6 +172,7 @@ register_migraphx_ops(
undefined undefined
unknown unknown
unsqueeze unsqueeze
where
) )
register_op(migraphx HEADER migraphx/op/rnn_variable_seq_lens.hpp OPERATORS op::rnn_var_sl_shift_output op::rnn_var_sl_shift_sequence) register_op(migraphx HEADER migraphx/op/rnn_variable_seq_lens.hpp OPERATORS op::rnn_var_sl_shift_output op::rnn_var_sl_shift_sequence)
register_op(migraphx HEADER migraphx/builtin.hpp OPERATORS builtin::literal builtin::param builtin::returns) register_op(migraphx HEADER migraphx/builtin.hpp OPERATORS builtin::literal builtin::param builtin::returns)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Elementwise select operator. For every element position the result takes its
// value from the second input when the first input (the condition) is true and
// from the third input otherwise.
struct where
{
    std::string name() const { return "where"; }

    // Operator metadata: flags the op as pointwise and supplies the per-element
    // expression (the ${N} tokens are presumably placeholders for the N-th
    // input -- confirm against the consumer of "point_op").
    value attributes() const
    {
        return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}};
    }

    // Validates that there are exactly three inputs with identical dimensions
    // and derives the output shape from the two value inputs (indices 1 and 2).
    // The condition input only takes part in the dimension check.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(3).same_dims();

        const auto& lhs = inputs.at(1);
        const auto& rhs = inputs.at(2);

        // Both value inputs share one packed shape: reuse it unchanged.
        if(lhs == rhs and lhs.packed())
            return lhs;

        // Exactly one of them is packed: the packed one wins.
        if(lhs.packed() != rhs.packed())
            return lhs.packed() ? lhs : rhs;

        // Exactly one of them is broadcasted: rebuild from the other one.
        if(lhs.broadcasted() != rhs.broadcasted())
        {
            const auto& basis = lhs.broadcasted() ? rhs : lhs;
            return basis.with_lens(lhs.lens());
        }

        // Otherwise fall back to a shape built from the type and dimensions only.
        return {lhs.type(), lhs.lens()};
    }

    // Evaluates the select on the host. args[0] is the condition, args[1] the
    // values used where it is true, and args[2] the values used where it is false.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        const auto count = output_shape.elements();

        visit_all(result, args[1], args[2])(
            [&](auto out, const auto on_true, const auto on_false) {
                // The condition may have a different element type than the
                // values, so it is visited separately.
                args[0].visit([&](const auto pick) {
                    par_for(count, [&](auto idx) {
                        out[idx] = pick[idx] ? on_true[idx] : on_false[idx];
                    });
                });
            });

        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -102,5 +102,6 @@ ...@@ -102,5 +102,6 @@
#include <migraphx/op/undefined.hpp> #include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp> #include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif #endif
...@@ -17,13 +17,13 @@ struct parse_where : op_parser<parse_where> ...@@ -17,13 +17,13 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto cond = auto lens =
info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]); compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); if(args[0]->get_shape().lens() != lens)
if(cond->get_shape().lens() != lens)
{ {
cond = info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), cond); args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
} }
if(args[1]->get_shape().lens() != lens) if(args[1]->get_shape().lens() != lens)
...@@ -38,24 +38,7 @@ struct parse_where : op_parser<parse_where> ...@@ -38,24 +38,7 @@ struct parse_where : op_parser<parse_where>
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
} }
// compute index return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = info.add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = info.add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = info.add_instruction(make_op("add"), ins_offset, l_ind);
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
} }
}; };
......
...@@ -86,6 +86,7 @@ add_library(migraphx_device ...@@ -86,6 +86,7 @@ add_library(migraphx_device
device/tanh.cpp device/tanh.cpp
device/topk.cpp device/topk.cpp
device/unary_not.cpp device/unary_not.cpp
device/where.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
...@@ -222,6 +223,7 @@ register_migraphx_gpu_ops(hip_ ...@@ -222,6 +223,7 @@ register_migraphx_gpu_ops(hip_
tan tan
topk topk
unary_not unary_not
where
) )
register_migraphx_gpu_ops(miopen_ register_migraphx_gpu_ops(miopen_
abs abs
......
#include <migraphx/gpu/device/where.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/launch.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Yields a value-initialised object of whatever type `Shape::hip_index::size()`
// returns. The argument itself is never inspected; only its static type is used.
template <class Shape>
constexpr auto get_rank(const Shape&)
{
    using index_type = typename Shape::hip_index;
    using rank_type  = decltype(index_type{}.size());
    return rank_type{};
}
// Launches an elementwise select on `stream`: each output element is taken from
// arg1 where the matching element of arg0 is truthy and from arg2 otherwise.
void where(hipStream_t stream,
           const argument& result,
           const argument& arg0,
           const argument& arg1,
           const argument& arg2)
{
    // One work item per element of the value input.
    const auto element_count = arg1.get_shape().elements();

    hip_visit_all(result, arg1, arg2)([&](auto output, auto on_true, auto on_false) {
        // The condition is visited on its own because its element type can
        // differ from that of the value inputs.
        hip_visit_all(arg0)([&](auto selector) {
            // The kernel is only instantiated when the condition and output
            // views report the same rank; otherwise nothing is launched.
            if constexpr(get_rank(selector.get_shape()) == get_rank(output.get_shape()))
            {
                gs_launch(stream, element_count)([=](auto flat) __device__ {
                    // Convert the flat work-item id to a multi-dimensional
                    // index so every view applies its own strides.
                    auto pos    = output.get_shape().multi(flat);
                    output[pos] = selector[pos] ? on_true[pos] : on_false[pos];
                });
            }
        });
    });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_WHERE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_WHERE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise select executed on the given HIP stream.
//   stream : stream the work is launched on
//   result : output argument that receives the selected values
//   arg0   : condition; a truthy element selects from arg1, otherwise arg2
//   arg1   : values used where the condition holds
//   arg2   : values used where the condition does not hold
void where(hipStream_t stream,
const argument& result,
const argument& arg0,
const argument& arg1,
const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/where.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU counterpart of the `where` operator, dispatching to device::where.
struct hip_where : ternary_device<hip_where, device::where>
{
    // Validates that there are exactly four inputs with identical dimensions
    // (presumably the three operands plus the output buffer -- confirm against
    // ternary_device) and derives the output shape from inputs 1 and 2.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4).same_dims();

        const auto& lhs = inputs.at(1);
        const auto& rhs = inputs.at(2);

        // Both value inputs share one packed shape: reuse it unchanged.
        if(lhs == rhs and lhs.packed())
            return lhs;

        // Exactly one of them is packed: the packed one wins.
        if(lhs.packed() != rhs.packed())
            return lhs.packed() ? lhs : rhs;

        // Exactly one of them is broadcasted: rebuild from the other one.
        if(lhs.broadcasted() != rhs.broadcasted())
        {
            const auto& basis = lhs.broadcasted() ? rhs : lhs;
            return basis.with_lens(lhs.lens());
        }

        // Otherwise fall back to a shape built from the type and dimensions only.
        return {lhs.type(), lhs.lens()};
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include <migraphx/gpu/quant_convolution.hpp> #include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp> #include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <utility> #include <utility>
...@@ -150,6 +151,7 @@ struct miopen_apply ...@@ -150,6 +151,7 @@ struct miopen_apply
add_generic_op("sub"); add_generic_op("sub");
add_generic_op("tan"); add_generic_op("tan");
add_generic_op("tanh"); add_generic_op("tanh");
add_generic_op("where");
add_extend_op("abs"); add_extend_op("abs");
add_extend_op("argmax"); add_extend_op("argmax");
......
...@@ -4041,32 +4041,14 @@ TEST_CASE(where_test) ...@@ -4041,32 +4041,14 @@ TEST_CASE(where_test)
auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}}); auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto int_c = mm->add_instruction( auto lccm =
migraphx::make_op("convert", mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lc);
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
lc);
auto lccm = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), int_c);
auto lxm = auto lxm =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lx); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lx);
auto lym = auto lym =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), ly); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), ly);
auto concat_data = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), lym, lxm); auto r = mm->add_instruction(migraphx::make_op("where"), lccm, lxm, lym);
auto rsp_data =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {32}}}), concat_data);
std::vector<int> offset(16, 16);
std::vector<int> ind(16);
std::iota(ind.begin(), ind.end(), 0);
migraphx::shape ind_s{migraphx::shape::int32_type, {2, 2, 2, 2}};
auto lind = mm->add_literal(migraphx::literal(ind_s, ind));
auto loffset = mm->add_literal(migraphx::literal(ind_s, offset));
auto ins_co = mm->add_instruction(migraphx::make_op("mul"), loffset, lccm);
auto ins_ind = mm->add_instruction(migraphx::make_op("add"), ins_co, lind);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("where_test.onnx"); auto prog = migraphx::parse_onnx("where_test.onnx");
......
...@@ -1598,4 +1598,12 @@ TEST_CASE(unary_broadcast_input) ...@@ -1598,4 +1598,12 @@ TEST_CASE(unary_broadcast_input)
expect_shape(s, migraphx::make_op("sin"), ss); expect_shape(s, migraphx::make_op("sin"), ss);
} }
// Shape inference for `where` when one value input carries a zero stride and
// the other uses default strides: the default-strided shape is expected back.
TEST_CASE(where_broadcast_input)
{
    migraphx::shape cond_shape{migraphx::shape::bool_type, {2, 2}};
    migraphx::shape zero_stride_shape{migraphx::shape::float_type, {2, 2}, {3, 0}};
    migraphx::shape plain_shape{migraphx::shape::float_type, {2, 2}};

    expect_shape(
        plain_shape, migraphx::make_op("where"), cond_shape, zero_stride_shape, plain_shape);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -4378,4 +4378,59 @@ TEST_CASE(unsqueeze_test) ...@@ -4378,4 +4378,59 @@ TEST_CASE(unsqueeze_test)
} }
} }
// Runs `where` on the reference target with three 3x3 literals and compares the
// result against a host-side elementwise select.
TEST_CASE(where_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
    migraphx::shape sx{migraphx::shape::float_type, {3, 3}};

    std::vector<bool> b{true, true, true, false, false, false, true, false, true};
    std::vector<float> x(9, 1.0);
    std::vector<float> y(9, 2.0);

    auto lb = mm->add_literal(migraphx::literal{sb, b});
    auto lx = mm->add_literal(migraphx::literal{sx, x});
    auto ly = mm->add_literal(migraphx::literal{sx, y});
    auto w  = mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
    mm->add_return({w});
    p.compile(migraphx::ref::target{});
    auto result = p.eval({}).back();
    std::vector<float> result_vec;
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });

    // Expected values; the index is unsigned to match gold.size() and avoid a
    // signed/unsigned comparison.
    std::vector<float> gold(9);
    for(std::size_t i = 0; i < gold.size(); ++i)
        gold[i] = b[i] ? x[i] : y[i];

    EXPECT(migraphx::verify_range(result_vec, gold));
}
// Runs `where` on the reference target with both value inputs being scalar
// literals multibroadcast to 3x3, and compares against a host-side select.
TEST_CASE(where_broadcasted_inputs_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
    std::vector<bool> b{true, true, true, false, false, false, true, false, true};

    auto lb  = mm->add_literal(migraphx::literal{sb, b});
    auto lx  = mm->add_literal(migraphx::literal(1.0f));
    auto ly  = mm->add_literal(migraphx::literal(2.0f));
    auto mbx = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), lx);
    auto mby = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), ly);
    auto w   = mm->add_instruction(migraphx::make_op("where"), lb, mbx, mby);
    mm->add_return({w});
    p.compile(migraphx::ref::target{});
    auto result = p.eval({}).back();
    std::vector<float> result_vec;
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });

    // Expected values, built from the expanded form of the two scalars; the
    // index is unsigned to match gold.size() and avoid a signed/unsigned comparison.
    std::vector<float> gold(9);
    std::vector<float> x(9, 1.0);
    std::vector<float> y(9, 2.0);
    for(std::size_t i = 0; i < gold.size(); ++i)
        gold[i] = b[i] ? x[i] : y[i];

    EXPECT(migraphx::verify_range(result_vec, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program: `where` applied to a bool condition and two float
// inputs that all have dimensions {1, 3, 4, 5}.
struct test_where : verify_program<test_where>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape cond_shape{migraphx::shape::bool_type, {1, 3, 4, 5}};
        const migraphx::shape data_shape{migraphx::shape::float_type, {1, 3, 4, 5}};

        auto cond     = main_mod->add_parameter("b", cond_shape);
        auto on_true  = main_mod->add_parameter("x", data_shape);
        auto on_false = main_mod->add_parameter("y", data_shape);

        auto selected =
            main_mod->add_instruction(migraphx::make_op("where"), cond, on_true, on_false);
        main_mod->add_return({selected});
        return prog;
    }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program: `where` with a {1, 3, 4, 5} bool condition and two
// single-element float inputs that are multibroadcast to the same dimensions.
struct test_where2 : verify_program<test_where2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape cond_shape{migraphx::shape::bool_type, {1, 3, 4, 5}};
        const migraphx::shape scalar_shape{migraphx::shape::float_type, {1}};

        auto cond     = main_mod->add_parameter("b", cond_shape);
        auto on_true  = main_mod->add_parameter("x", scalar_shape);
        auto on_false = main_mod->add_parameter("y", scalar_shape);

        // Expand both single-element inputs to the condition's dimensions.
        auto wide_true = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), on_true);
        auto wide_false = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), on_false);

        auto selected =
            main_mod->add_instruction(migraphx::make_op("where"), cond, wide_true, wide_false);
        main_mod->add_return({selected});
        return prog;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment