Commit 6f96cf7e authored by Paul
Browse files

Formatting

parent cff1144d
......@@ -7,9 +7,7 @@ namespace device {
// Fused add + ReLU: hands arg1, arg2 and result to binary_standard together with a
// lambda that computes max(0, x + y) for a pair of input elements.
// NOTE(review): this span is a rendered formatting diff. The three-line binary_standard
// call is the pre-format version; the single-line call after it is its reformatted
// replacement, not a second invocation.
void add_relu(argument arg1, argument arg2, argument result)
{
binary_standard(arg1, arg2, result, [](auto x, auto y) {
return max(0, x + y);
});
binary_standard(arg1, arg2, result, [](auto x, auto y) { return max(0, x + y); });
}
} // namespace device
......
......@@ -8,7 +8,7 @@ namespace migraph {
namespace gpu {
namespace device {
template<class F>
template <class F>
void binary(argument x, argument y, argument result, F f)
{
if(x.get_shape().standard())
......@@ -17,42 +17,40 @@ void binary(argument x, argument y, argument result, F f)
binary_nonstandard(x, y, result, f);
}
// Applies the binary functor `f` element-wise to `x` and `y` and stores the results in
// `result`, translating each output position into every input's own offset through
// tensor descriptors built from lens/strides.
// NOTE(review): this span is a rendered formatting diff, so several statements appear
// twice in a row (the pre-format line followed by its reformatted replacement).
template<class F>
template <class F>
void binary_nonstandard(argument x, argument y, argument result, F f)
{
auto output_shape = result.get_shape();
// Only read below to obtain the element count passed to gs_launch.
auto input_shape = x.get_shape();
auto input_shape = x.get_shape();
visit_all(result, x, y)([&](auto output, auto input1, auto input2) {
// `ndim` is supplied by visit_tensor_size from the output rank and is used as the
// template argument of the descriptors.
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
// One descriptor per tensor, each constructed from that tensor's own lens and strides.
hip_tensor_descriptor<ndim> x_desc(x.get_shape().lens(), x.get_shape().strides());
hip_tensor_descriptor<ndim> y_desc(y.get_shape().lens(), y.get_shape().strides());
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto* xp = input1.data();
auto* yp = input2.data();
auto* xp = input1.data();
auto* yp = input2.data();
auto* outp = output.data();
// The callable captures by copy ([=]): it carries its own copies of the descriptors,
// the raw data pointers and `f`.
// NOTE(review): the launch size is x's element count, not result's — presumably the
// two are equal for every caller; confirm.
gs_launch(input_shape.elements())([=](auto i) {
// Expand index i with the output descriptor, then flatten that multi-index with each
// input's descriptor to find where to read from.
auto outidx = out_desc.multi(i);
size_t xidx = x_desc.linear(outidx);
size_t yidx = y_desc.linear(outidx);
// The output itself is written at the plain index i.
outp[i] = f(xp[xidx], yp[yidx]);
outp[i] = f(xp[xidx], yp[yidx]);
});
});
});
}
// Applies the binary functor `f` element-wise, reading x, y and writing result with the
// same index i in all three buffers (no descriptor-based index translation).
// NOTE(review): this span is a rendered formatting diff, so several statements appear
// twice in a row (the pre-format line followed by its reformatted replacement).
template<class F>
template <class F>
void binary_standard(argument x, argument y, argument result, F f)
{
// Both inputs must hold the same number of elements; result's size is not checked here.
assert(x.get_shape().elements() == y.get_shape().elements());
// NOTE(review): output_shape is not referenced again in the visible body.
auto output_shape = result.get_shape();
auto input_shape = x.get_shape();
auto input_shape = x.get_shape();
visit_all(result, x, y)([&](auto output, auto input1, auto input2) {
auto* xp = input1.data();
auto* yp = input2.data();
auto* xp = input1.data();
auto* yp = input2.data();
auto* outp = output.data();
// The launched callable captures the pointers and `f` by copy ([=]); the launch size is
// x's element count.
gs_launch(input_shape.elements())([=](auto i) {
outp[i] = f(xp[i], yp[i]);
});
gs_launch(input_shape.elements())([=](auto i) { outp[i] = f(xp[i], yp[i]); });
});
}
......
......@@ -9,7 +9,7 @@ namespace migraph {
namespace gpu {
namespace device {
template<class F>
template <class F>
void unary(argument x, argument result, F f)
{
if(x.get_shape().standard())
......@@ -18,36 +18,34 @@ void unary(argument x, argument result, F f)
unary_nonstandard(x, result, f);
}
// Applies the unary functor `f` element-wise to `x` and stores the results in `result`,
// translating each output position into x's own offset through tensor descriptors built
// from lens/strides.
// NOTE(review): this span is a rendered formatting diff, so several statements appear
// twice in a row (the pre-format line followed by its reformatted replacement).
template<class F>
template <class F>
void unary_nonstandard(argument x, argument result, F f)
{
auto output_shape = result.get_shape();
auto input_shape = x.get_shape();
auto input_shape = x.get_shape();
visit_all(result, x)([&](auto output, auto input) {
// `ndim` is supplied by visit_tensor_size from the output rank and is used as the
// template argument of the descriptors.
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
hip_tensor_descriptor<ndim> x_desc(input_shape.lens(), input_shape.strides());
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto* xp = input.data();
auto* xp = input.data();
auto* outp = output.data();
// The callable captures by copy ([=]); the launch size is x's element count.
gs_launch(input_shape.elements())([=](auto i) {
// Expand index i with the output descriptor, then flatten that multi-index with the
// input descriptor to find where to read from; the write goes to plain index i.
size_t xidx = x_desc.linear(out_desc.multi(i));
outp[i] = f(xp[xidx]);
outp[i] = f(xp[xidx]);
});
});
});
}
// Applies the unary functor `f` element-wise, reading x and writing result with the same
// index i in both buffers (no descriptor-based index translation).
// NOTE(review): this span is a rendered formatting diff, so several statements appear
// twice in a row (the pre-format line followed by its reformatted replacement).
template<class F>
template <class F>
void unary_standard(argument x, argument result, F f)
{
// NOTE(review): output_shape is not referenced again in the visible body.
auto output_shape = result.get_shape();
auto input_shape = x.get_shape();
auto input_shape = x.get_shape();
visit_all(result, x)([&](auto output, auto input) {
auto* xp = input.data();
auto* xp = input.data();
auto* outp = output.data();
// The launched callable captures the pointers and `f` by copy ([=]); the launch size is
// x's element count.
gs_launch(input_shape.elements())([=](auto i) {
outp[i] = f(xp[i]);
});
gs_launch(input_shape.elements())([=](auto i) { outp[i] = f(xp[i]); });
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment