"docs/source/nas/overview.rst" did not exist on "fac7364a09d03f46a660aec7e6f3c911ed223336"
Commit 6ce07611 authored by Paul's avatar Paul
Browse files

Formatting

parent b5c0f7ef
......@@ -63,7 +63,7 @@ constexpr void repeat_c_impl(F f, seq<Ns...>)
} // namespace detail
template<std::size_t N, class F>
template <std::size_t N, class F>
constexpr void repeat_c(F f)
{
detail::repeat_c_impl(f, detail::gens<N>{});
......
......@@ -16,9 +16,8 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data =
pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
inputs.data())...);
auto data = pack(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
......@@ -57,8 +56,9 @@ auto nary(argument result, Arguments... args)
{
return [=](auto f) {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes))
nary_standard(result, args...)(f);
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment