Unverified Commit ebbaf8fc authored by turneram's avatar turneram Committed by GitHub
Browse files

Refactor where op (#918)

Implement the Where operator for the CPU and GPU.  This is for better performance.
parent 521b57a2
......@@ -172,6 +172,7 @@ register_migraphx_ops(
undefined
unknown
unsqueeze
where
)
register_op(migraphx HEADER migraphx/op/rnn_variable_seq_lens.hpp OPERATORS op::rnn_var_sl_shift_output op::rnn_var_sl_shift_sequence)
register_op(migraphx HEADER migraphx/builtin.hpp OPERATORS builtin::literal builtin::param builtin::returns)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct where
{
std::string name() const { return "where"; }
value attributes() const { return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}}; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -102,5 +102,6 @@
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif
......@@ -17,13 +17,13 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto cond =
info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens());
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens)
if(args[0]->get_shape().lens() != lens)
{
cond = info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), cond);
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
}
if(args[1]->get_shape().lens() != lens)
......@@ -38,24 +38,7 @@ struct parse_where : op_parser<parse_where>
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
// compute index
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = info.add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = info.add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = info.add_instruction(make_op("add"), ins_offset, l_ind);
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
};
......
......@@ -86,6 +86,7 @@ add_library(migraphx_device
device/tanh.cpp
device/topk.cpp
device/unary_not.cpp
device/where.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
......@@ -222,6 +223,7 @@ register_migraphx_gpu_ops(hip_
tan
topk
unary_not
where
)
register_migraphx_gpu_ops(miopen_
abs
......
#include <migraphx/gpu/device/where.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/launch.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class Shape>
constexpr auto get_rank(const Shape&)
{
return decltype(typename Shape::hip_index{}.size()){};
}
void where(hipStream_t stream,
const argument& result,
const argument& arg0,
const argument& arg1,
const argument& arg2)
{
hip_visit_all(result, arg1, arg2)([&](auto output, auto x, auto y) {
hip_visit_all(arg0)([&](auto cond) {
if constexpr(get_rank(cond.get_shape()) == get_rank(output.get_shape()))
{
gs_launch(stream, arg1.get_shape().elements())([=](auto idx) __device__ {
auto i = output.get_shape().multi(idx);
output[i] = cond[i] ? x[i] : y[i];
});
}
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_WHERE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_WHERE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void where(hipStream_t stream,
const argument& result,
const argument& arg0,
const argument& arg1,
const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/where.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_where : ternary_device<hip_where, device::where>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -38,6 +38,7 @@
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
......@@ -150,6 +151,7 @@ struct miopen_apply
add_generic_op("sub");
add_generic_op("tan");
add_generic_op("tanh");
add_generic_op("where");
add_extend_op("abs");
add_extend_op("argmax");
......
......@@ -4041,32 +4041,14 @@ TEST_CASE(where_test)
auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto int_c = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
lc);
auto lccm = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), int_c);
auto lccm =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lc);
auto lxm =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lx);
auto lym =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), ly);
auto concat_data = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), lym, lxm);
auto rsp_data =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {32}}}), concat_data);
std::vector<int> offset(16, 16);
std::vector<int> ind(16);
std::iota(ind.begin(), ind.end(), 0);
migraphx::shape ind_s{migraphx::shape::int32_type, {2, 2, 2, 2}};
auto lind = mm->add_literal(migraphx::literal(ind_s, ind));
auto loffset = mm->add_literal(migraphx::literal(ind_s, offset));
auto ins_co = mm->add_instruction(migraphx::make_op("mul"), loffset, lccm);
auto ins_ind = mm->add_instruction(migraphx::make_op("add"), ins_co, lind);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
auto r = mm->add_instruction(migraphx::make_op("where"), lccm, lxm, lym);
mm->add_return({r});
auto prog = migraphx::parse_onnx("where_test.onnx");
......
......@@ -1598,4 +1598,12 @@ TEST_CASE(unary_broadcast_input)
expect_shape(s, migraphx::make_op("sin"), ss);
}
TEST_CASE(where_broadcast_input)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {3, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 2}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}};
expect_shape(s2, migraphx::make_op("where"), s3, s1, s2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4378,4 +4378,59 @@ TEST_CASE(unsqueeze_test)
}
}
TEST_CASE(where_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
migraphx::shape sx{migraphx::shape::float_type, {3, 3}};
std::vector<bool> b{true, true, true, false, false, false, true, false, true};
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
auto lb = mm->add_literal(migraphx::literal{sb, b});
auto lx = mm->add_literal(migraphx::literal{sx, x});
auto ly = mm->add_literal(migraphx::literal{sx, y});
auto w = mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
mm->add_return({w});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<float> gold(9);
for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify_range(result_vec, gold));
}
TEST_CASE(where_broadcasted_inputs_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
std::vector<bool> b{true, true, true, false, false, false, true, false, true};
auto lb = mm->add_literal(migraphx::literal{sb, b});
auto lx = mm->add_literal(migraphx::literal(1.0f));
auto ly = mm->add_literal(migraphx::literal(2.0f));
auto mbx = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), lx);
auto mby = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), ly);
auto w = mm->add_instruction(migraphx::make_op("where"), lb, mbx, mby);
mm->add_return({w});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<float> gold(9);
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify_range(result_vec, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {1, 3, 4, 5}};
migraphx::shape sx{migraphx::shape::float_type, {1, 3, 4, 5}};
auto b = mm->add_parameter("b", sb);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sx);
auto r = mm->add_instruction(migraphx::make_op("where"), b, x, y);
mm->add_return({r});
return p;
};
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_where2 : verify_program<test_where2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {1, 3, 4, 5}};
migraphx::shape sx{migraphx::shape::float_type, {1}};
auto b = mm->add_parameter("b", sb);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sx);
auto mbx = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), x);
auto mby = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), y);
auto r = mm->add_instruction(migraphx::make_op("where"), b, mbx, mby);
mm->add_return({r});
return p;
};
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment