"...git@developer.sourcefind.cn:jerrrrry/dcu_megatron.git" did not exist on "4e867b3c37ca3ee38602c31011ced53e41f4071a"
Commit c7888300 authored by Paul's avatar Paul
Browse files

Formatting

parent a919e88a
...@@ -34,12 +34,10 @@ auto fix(F f) ...@@ -34,12 +34,10 @@ auto fix(F f)
return fix<void>(f); return fix<void>(f);
} }
template<class... Ts> template <class... Ts>
auto make_sequence(Ts... xs) auto make_sequence(Ts... xs)
{ {
return [=](auto f) { return [=](auto f) { return f(xs...); };
return f(xs...);
};
} }
} // namespace migraph } // namespace migraph
......
...@@ -29,7 +29,10 @@ auto nary_nonstandard(argument result, Arguments... args) ...@@ -29,7 +29,10 @@ auto nary_nonstandard(argument result, Arguments... args)
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = make_sequence(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(), inputs.get_shape().strides()}, inputs.data())...); auto data = make_sequence(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
inputs.get_shape().strides()},
inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides()); hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto* outp = output.data(); auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) { gs_launch(output_shape.elements())([=](auto i) {
...@@ -52,11 +55,8 @@ auto nary_standard(argument result, Arguments... args) ...@@ -52,11 +55,8 @@ auto nary_standard(argument result, Arguments... args)
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = make_sequence(inputs.data()...); auto data = make_sequence(inputs.data()...);
auto* outp = output.data(); auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) { gs_launch(output_shape.elements())(
data([&](auto... xps) { [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
outp[i] = f(xps[i]...);
});
});
}); });
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment