Commit d37f0677 authored by Paul's avatar Paul
Browse files

Formatting

parent 442581b9
......@@ -16,9 +16,8 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data =
pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
inputs.data())...);
auto data = pack(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
......@@ -41,26 +40,34 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{
return [=](auto f) {
const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { return x != 0; }));
auto bdim_len = b_shape.lens()[bdim];
auto outer_size = std::accumulate(output_shape.lens().begin(), output_shape.lens().begin() + bdim + 1, std::size_t{1}, std::multiplies<>{});
auto inner_size = std::accumulate(output_shape.lens().begin()+bdim+1, output_shape.lens().end(), std::size_t{1}, std::multiplies<>{});
const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(),
b_shape.strides().end(),
[](auto x) { return x != 0; }));
auto bdim_len = b_shape.lens()[bdim];
auto outer_size = std::accumulate(output_shape.lens().begin(),
output_shape.lens().begin() + bdim + 1,
std::size_t{1},
std::multiplies<>{});
auto inner_size = std::accumulate(output_shape.lens().begin() + bdim + 1,
output_shape.lens().end(),
std::size_t{1},
std::multiplies<>{});
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
auto* xp = input1.data();
auto* yp = input2.data();
auto* xp = input1.data();
auto* yp = input2.data();
auto* outp = output.data();
gs_launch(outer_size)(
[=](auto i) {
auto * outp2 = outp + i;
auto * xp2 = xp + i;
auto b = yp[i % bdim_len];
for(std::size_t j = 0;j < inner_size;j++)
{
outp2[j] = f(xp2[j], b);
}
});
gs_launch(outer_size)([=](auto i) {
auto* outp2 = outp + i;
auto* xp2 = xp + i;
auto b = yp[i % bdim_len];
for(std::size_t j = 0; j < inner_size; j++)
{
outp2[j] = f(xp2[j], b);
}
});
});
};
}
......@@ -85,8 +92,9 @@ auto nary_impl(argument result, Arguments... args)
{
return [=](auto f) {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes))
nary_standard(result, args...)(f);
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment