Commit a919e88a authored by Paul's avatar Paul
Browse files

Add nary to handle any number of arguments

parent 032d0650
......@@ -34,6 +34,14 @@ auto fix(F f)
return fix<void>(f);
}
template<class... Ts>
auto make_sequence(Ts... xs)
{
return [=](auto f) {
return f(xs...);
};
}
} // namespace migraph
#endif
......@@ -60,6 +60,18 @@ bool contains(const std::initializer_list<T>& c, const U& x)
return generic_find(c, x) != c.end();
}
template <class C, class Predicate>
bool all_of(const C& c, const Predicate& p)
{
return std::all_of(c.begin(), c.end(), p);
}
template <class T, class Predicate>
bool all_of(const std::initializer_list<T>& c, const Predicate& p)
{
return std::all_of(c.begin(), c.end(), p);
}
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
......
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/binary.hpp>
#include <migraph/gpu/device/add_relu.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace migraph {
namespace gpu {
namespace device {
void add_relu(argument arg1, argument arg2, argument result)
void add_relu(argument result, argument arg1, argument arg2)
{
binary_standard(arg1, arg2, result, [](auto x, auto y) { return max(0, x + y); });
nary_standard(result, arg1, arg2)([](auto x, auto y) { return max(0, x + y); });
}
} // namespace device
......
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/unary.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace migraph {
namespace gpu {
namespace device {
void contiguous(argument arg, argument result)
void contiguous(argument result, argument arg)
{
unary_nonstandard(arg, result, [](auto x) { return x; });
nary_nonstandard(result, arg)([](auto x) { return x; });
}
} // namespace device
......
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
namespace migraph {
namespace gpu {
namespace device {
template <class... Arguments>
auto nary(argument result, Arguments... args)
{
return [=](auto f) {
if(all_of({args...}, [](const shape& s) { return s.standard(); }))
nary_standard(result, args...)(f);
else
nary_nonstandard(result, args...)(f);
};
}
template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args)
{
return [=](auto f) {
auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = make_sequence(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(), inputs.get_shape().strides()}, inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... ps) {
auto outidx = out_desc.multi(i);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
});
});
});
});
};
}
template <class... Arguments>
auto nary_standard(argument result, Arguments... args)
{
return [=](auto f) {
// assert(x.get_shape().elements() == y.get_shape().elements());
auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = make_sequence(inputs.data()...);
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... xps) {
outp[i] = f(xps[i]...);
});
});
});
};
}
} // namespace device
} // namespace gpu
} // namespace migraph
#endif
......@@ -17,7 +17,7 @@ struct hip_add_relu
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
device::add_relu(args.at(0), args.at(1), args.at(2));
device::add_relu(args.at(2), args.at(0), args.at(1));
return args.at(2);
}
};
......
......@@ -8,7 +8,7 @@ namespace migraph {
namespace gpu {
namespace device {
void add_relu(argument arg1, argument arg2, argument result);
void add_relu(argument result, argument arg1, argument arg2);
} // namespace device
} // namespace gpu
......
......@@ -7,7 +7,7 @@ namespace migraph {
namespace gpu {
namespace device {
void contiguous(argument arg, argument result);
void contiguous(argument result, argument arg);
} // namespace device
} // namespace gpu
......
......@@ -253,7 +253,7 @@ struct miopen_contiguous
assert(output_shape == args[1].get_shape());
assert(output_shape.standard());
(void)output_shape;
device::contiguous(args.at(0), args.at(1));
device::contiguous(args.at(1), args.at(0));
return args.at(1);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment