Commit d15edcb6 authored by Paul's avatar Paul
Browse files

Formatting

parent ce3048d4
......@@ -55,10 +55,12 @@ template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
if(!std::all_of(types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
if(!std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(),
[&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
......@@ -68,10 +70,10 @@ template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(),
[&](auto ndim) { v(f(xs, ndim)...); });
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
}
template <class F>
......@@ -152,8 +154,11 @@ template <class T, class... Ts>
auto hip_visit_views(T&& x, Ts&&... xs)
{
return [&](auto f) {
hip_visit_views_impl(
get_shape(x), make_hip_convert_view([](auto v) { return device_cast(v); }), f, x, xs...);
hip_visit_views_impl(get_shape(x),
make_hip_convert_view([](auto v) { return device_cast(v); }),
f,
x,
xs...);
};
}
......
......@@ -12,9 +12,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs);
}
argument hip_gather::compute(context& ctx,
const shape&,
const std::vector<argument>& args) const
argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment