".github/vscode:/vscode.git/clone" did not exist on "bad334fa5b9dd9d8efb98c2cd1b1bfe533434322"
Commit a919e88a authored by Paul

Add nary to handle any number of arguments
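
The new helpers replace the fixed-arity unary_*/binary_* device functions
with a single variadic family: nary(result, args...) returns a function
object that is then applied to the element-wise lambda. As in this diff,
add_relu becomes:

    nary_standard(result, arg1, arg2)([](auto x, auto y) { return max(0, x + y); });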

parent 032d0650
@@ -34,6 +34,14 @@ auto fix(F f)
     return fix<void>(f);
 }
 
+template<class... Ts>
+auto make_sequence(Ts... xs)
+{
+    return [=](auto f) {
+        return f(xs...);
+    };
+}
+
 } // namespace migraph
 #endif
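
make_sequence captures its arguments by value and replays them into whatever
callable it is later given; this is what lets the nary kernels below carry a
pack of per-input pointers (or descriptor/pointer pairs) into the device
lambda. A minimal host-side sketch with hypothetical values:

    auto seq = make_sequence(1, 2.0);       // captures the pack by value
    auto sum = seq([](int a, double b) {    // replays it as f(1, 2.0)
        return a + b;                       // 3.0
    });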
@@ -60,6 +60,18 @@ bool contains(const std::initializer_list<T>& c, const U& x)
     return generic_find(c, x) != c.end();
 }
 
+template <class C, class Predicate>
+bool all_of(const C& c, const Predicate& p)
+{
+    return std::all_of(c.begin(), c.end(), p);
+}
+
+template <class T, class Predicate>
+bool all_of(const std::initializer_list<T>& c, const Predicate& p)
+{
+    return std::all_of(c.begin(), c.end(), p);
+}
+
 template <class Range, class Iterator>
 void copy(Range&& r, Iterator it)
 {
......
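
The std::initializer_list overload is what allows a braced list to be passed
directly, as the nary dispatch below does with the argument shapes. A small
sketch with hypothetical values:

    bool ok = all_of({2, 4, 6}, [](int x) { return x % 2 == 0; }); // true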
-#include <migraph/gpu/device/contiguous.hpp>
-#include <migraph/gpu/device/binary.hpp>
+#include <migraph/gpu/device/add_relu.hpp>
+#include <migraph/gpu/device/nary.hpp>
 
 namespace migraph {
 namespace gpu {
 namespace device {
 
-void add_relu(argument arg1, argument arg2, argument result)
+void add_relu(argument result, argument arg1, argument arg2)
 {
-    binary_standard(arg1, arg2, result, [](auto x, auto y) { return max(0, x + y); });
+    nary_standard(result, arg1, arg2)([](auto x, auto y) { return max(0, x + y); });
 }
 
 } // namespace device
......
 #include <migraph/gpu/device/contiguous.hpp>
-#include <migraph/gpu/device/unary.hpp>
+#include <migraph/gpu/device/nary.hpp>
 
 namespace migraph {
 namespace gpu {
 namespace device {
 
-void contiguous(argument arg, argument result)
+void contiguous(argument result, argument arg)
 {
-    unary_nonstandard(arg, result, [](auto x) { return x; });
+    nary_nonstandard(result, arg)([](auto x) { return x; });
 }
 
 } // namespace device
......
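
Routing contiguous through nary_nonstandard with the identity lambda turns it
into a layout-normalizing copy: each packed output element is read from the
input through the input's own strides. A simplified single-input CPU analogue
of that indexing (hypothetical names, not the actual device code):

    for(std::size_t i = 0; i < out_elements; i++)
    {
        auto idx = out_desc.multi(i);       // packed linear index -> multi-index
        out[i]   = in[in_desc.linear(idx)]; // multi-index -> strided input offset
    }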
+#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
+#define MIGRAPH_GUARD_RTGLIB_DEVICE_NARY_HPP
+
+#include <migraph/gpu/device/tensor.hpp>
+#include <migraph/gpu/device/launch.hpp>
+#include <migraph/functional.hpp>
+#include <migraph/ranges.hpp>
+
+namespace migraph {
+namespace gpu {
+namespace device {
+
+// Generic path: the output is packed, but the inputs may be strided
+// (e.g. broadcast or transposed), so each input is addressed through
+// its own tensor descriptor.
+template <class... Arguments>
+auto nary_nonstandard(argument result, Arguments... args)
+{
+    return [=](auto f) {
+        auto output_shape = result.get_shape();
+        visit_all(result, args...)([&](auto output, auto... inputs) {
+            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
+                // Pair each input's descriptor with its data pointer so the
+                // whole pack can be captured and replayed on the device.
+                auto data = make_sequence(std::make_pair(
+                    hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
+                                                inputs.get_shape().strides()},
+                    inputs.data())...);
+                hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
+                auto* outp = output.data();
+                gs_launch(output_shape.elements())([=](auto i) {
+                    data([&](auto... ps) {
+                        // Map the output's multi-index through each input's strides.
+                        auto outidx = out_desc.multi(i);
+                        outp[i]     = f(ps.second[ps.first.linear(outidx)]...);
+                    });
+                });
+            });
+        });
+    };
+}
+
+// Fast path: every argument is packed, so element i of each input
+// lines up with element i of the output.
+template <class... Arguments>
+auto nary_standard(argument result, Arguments... args)
+{
+    return [=](auto f) {
+        // All arguments are assumed to have the same number of elements as the result.
+        auto output_shape = result.get_shape();
+        visit_all(result, args...)([&](auto output, auto... inputs) {
+            auto data  = make_sequence(inputs.data()...);
+            auto* outp = output.data();
+            gs_launch(output_shape.elements())([=](auto i) {
+                data([&](auto... xps) { outp[i] = f(xps[i]...); });
+            });
+        });
+    };
+}
+
+// Dispatch on the shapes: use the fast path only when every argument
+// has a standard (packed, row-major) shape.
+template <class... Arguments>
+auto nary(argument result, Arguments... args)
+{
+    return [=](auto f) {
+        if(all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }))
+            nary_standard(result, args...)(f);
+        else
+            nary_nonstandard(result, args...)(f);
+    };
+}
+
+} // namespace device
+} // namespace gpu
+} // namespace migraph
+#endif
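
A hypothetical call site, assuming device arguments a, b and a preallocated
result: nary picks nary_standard when every input shape is standard (packed,
row-major) and falls back to nary_nonstandard otherwise, so the same call
works for broadcast or transposed inputs:

    nary(result, a, b)([](auto x, auto y) { return x * y; });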
@@ -17,7 +17,7 @@ struct hip_add_relu
     }
     argument compute(context&, const shape&, const std::vector<argument>& args) const
     {
-        device::add_relu(args.at(0), args.at(1), args.at(2));
+        device::add_relu(args.at(2), args.at(0), args.at(1));
         return args.at(2);
     }
 };
......
@@ -8,7 +8,7 @@ namespace migraph {
 namespace gpu {
 namespace device {
 
-void add_relu(argument arg1, argument arg2, argument result);
+void add_relu(argument result, argument arg1, argument arg2);
 
 } // namespace device
 } // namespace gpu
......
@@ -7,7 +7,7 @@ namespace migraph {
 namespace gpu {
 namespace device {
 
-void contiguous(argument arg, argument result);
+void contiguous(argument result, argument arg);
 
 } // namespace device
 } // namespace gpu
......
@@ -253,7 +253,7 @@ struct miopen_contiguous
         assert(output_shape == args[1].get_shape());
         assert(output_shape.standard());
         (void)output_shape;
-        device::contiguous(args.at(0), args.at(1));
+        device::contiguous(args.at(1), args.at(0));
         return args.at(1);
     }
 };
......