Commit 6ce07611 authored by Paul's avatar Paul
Browse files

Formatting

parent b5c0f7ef
...@@ -63,7 +63,7 @@ constexpr void repeat_c_impl(F f, seq<Ns...>) ...@@ -63,7 +63,7 @@ constexpr void repeat_c_impl(F f, seq<Ns...>)
} // namespace detail } // namespace detail
template<std::size_t N, class F> template <std::size_t N, class F>
constexpr void repeat_c(F f) constexpr void repeat_c(F f)
{ {
detail::repeat_c_impl(f, detail::gens<N>{}); detail::repeat_c_impl(f, detail::gens<N>{});
......
...@@ -16,9 +16,8 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args) ...@@ -16,9 +16,8 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = auto data = pack(
pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape); hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = output.data(); auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) { gs_launch(output_shape.elements())([=](auto i) {
...@@ -57,8 +56,9 @@ auto nary(argument result, Arguments... args) ...@@ -57,8 +56,9 @@ auto nary(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); }); bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes)) if(standard or (packed and same_shapes))
nary_standard(result, args...)(f); nary_standard(result, args...)(f);
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment