Commit 442581b9 authored by Paul
Browse files

Try to optimize broadcast add

parent 6ce07611
......@@ -16,8 +16,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = pack(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
auto data =
pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
......@@ -36,6 +37,34 @@ auto nary_nonstandard(argument result, Arguments... args)
return [=](auto f) { return nary_nonstandard_impl(f, result, args...); };
}
/// Build a launcher for a binary elementwise operation whose second operand
/// varies along a single axis only.
///
/// The returned callable takes the functor `f` and stores `f(x, b)` for every
/// element `x` of `arg1`, where `b` is the element of `arg2` selected by the
/// coordinate of `x` along the broadcast axis.
///
/// Preconditions (none of them are checked here; the caller must ensure them):
///  - `arg2` has at least one non-zero stride. The first non-zero stride marks
///    the broadcast axis; with all-zero strides `bdim` would index past the
///    end of the lens/strides vectors.
///  - `result` and `arg1` are addressed with the same linear, row-major
///    offsets derived from the lens of `result`, so both buffers must be laid
///    out that way.
inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{
    return [=](auto f) {
        const auto& output_shape = result.get_shape();
        const auto& out_lens     = output_shape.lens();
        const auto& b_shape      = arg2.get_shape();
        const auto& b_strides    = b_shape.strides();
        // Index of the first axis along which arg2 actually advances in memory.
        auto bdim        = std::distance(b_strides.begin(),
                                  std::find_if(b_strides.begin(), b_strides.end(), [](auto x) {
                                      return x != 0;
                                  }));
        auto bdim_len    = b_shape.lens()[bdim];
        auto bdim_stride = b_strides[bdim];
        // Number of rows: product of all dimensions up to and including bdim.
        auto outer_size = std::accumulate(
            out_lens.begin(), out_lens.begin() + bdim + 1, std::size_t{1}, std::multiplies<>{});
        // Row length: product of all dimensions after bdim.
        auto inner_size = std::accumulate(
            out_lens.begin() + bdim + 1, out_lens.end(), std::size_t{1}, std::multiplies<>{});
        visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
            auto* xp   = input1.data();
            auto* yp   = input2.data();
            auto* outp = output.data();
            gs_launch(outer_size)([=](auto i) {
                // Each work item owns one row of inner_size consecutive
                // elements, so row i starts at i * inner_size (not at i).
                const auto row_offset = i * inner_size;
                auto* outp2           = outp + row_offset;
                auto* xp2             = xp + row_offset;
                // i % bdim_len is the coordinate along the broadcast axis;
                // scale it by that axis' stride to locate the operand.
                auto b = yp[(i % bdim_len) * bdim_stride];
                for(std::size_t j = 0; j < inner_size; j++)
                {
                    outp2[j] = f(xp2[j], b);
                }
            });
        });
    };
}
template <class... Arguments>
auto nary_standard(argument result, Arguments... args)
{
......@@ -52,13 +81,12 @@ auto nary_standard(argument result, Arguments... args)
}
template <class... Arguments>
auto nary(argument result, Arguments... args)
auto nary_impl(argument result, Arguments... args)
{
return [=](auto f) {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes))
nary_standard(result, args...)(f);
else
......@@ -67,6 +95,29 @@ auto nary(argument result, Arguments... args)
};
}
/// Public entry point for n-ary elementwise operations.
///
/// Forwards `result` and every input argument to `nary_impl` and hands back
/// the callable it produces; that callable is invoked with the functor to
/// apply.
template <class... Arguments>
auto nary(argument result, Arguments... args)
{
    auto launcher = nary_impl(result, args...);
    return launcher;
}
#if 0
/// Binary overload of `nary` that routes the single-axis broadcast case to
/// `binary_broadcast` and everything else to the generic `nary_impl`.
///
/// The fast path is taken only when all of the following hold:
///  - `arg1` is in standard layout,
///  - `arg1` has exactly the shape of `result` (binary_broadcast indexes both
///    buffers with the same linear offsets),
///  - `arg2` is broadcasted and exactly one of its strides is non-zero.
inline auto nary(argument result, argument arg1, argument arg2)
{
    return [=](auto f) {
        const auto& shape1   = arg1.get_shape();
        const auto& shape2   = arg2.get_shape();
        const auto& strides2 = shape2.strides();
        // Evaluated lazily so the cheap layout checks can reject first.
        const auto one_broadcast_axis = [&] {
            return std::count_if(strides2.begin(), strides2.end(), [](auto x) {
                       return x != 0;
                   }) == 1;
        };
        if(shape1.standard() and shape1 == result.get_shape() and shape2.broadcasted() and
           one_broadcast_axis())
        {
            binary_broadcast(result, arg1, arg2)(f);
        }
        else
        {
            nary_impl(result, arg1, arg2)(f);
        }
    };
}
#endif
} // namespace device
} // namespace gpu
} // namespace migraph
......
......@@ -210,6 +210,20 @@ struct test_add_broadcast
}
};
/// Verification case for adding a broadcast operand: `x` has lens {2, 3, 4},
/// `y` has lens {3}, and `y` is passed through `migraph::broadcast{1}` before
/// the add (presumably broadcasting along axis 1, whose length matches `y` —
/// confirm against the broadcast operator).
struct test_add_broadcast2
{
    /// Builds the program `add(x, broadcast(x, y))`.
    migraph::program create_program() const
    {
        migraph::program p;
        auto x  = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 4}});
        auto y  = p.add_parameter("y", {migraph::shape::float_type, {3}});
        auto by = p.add_instruction(migraph::broadcast{1}, x, y);
        p.add_instruction(migraph::add{}, x, by);
        return p;
    }
};
struct test_conv_relu
{
migraph::program create_program() const
......@@ -418,6 +432,7 @@ int main()
{
verify_program<test_add>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_conv_relu>();
verify_program<test_add_relu>();
verify_program<test_conv_pooling>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment