Commit d15edcb6 authored by Paul's avatar Paul
Browse files

Formatting

parent ce3048d4
......@@ -13,22 +13,22 @@ namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
{
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
lens[axis_index] = arg2.get_shape().elements();
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
lens[axis_index] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input_v) {
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data());
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) {
auto idx = out_comp.multi(i);
idx[axis_index] = indices_ptr[idx[axis_index]];
output_ptr[i] = input[idx];
output_ptr[i] = input[idx];
});
});
});
......
......@@ -55,10 +55,12 @@ template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
if(!std::all_of(types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
if(!std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(),
[&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
......@@ -68,10 +70,10 @@ template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(),
[&](auto ndim) { v(f(xs, ndim)...); });
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
}
template <class F>
......@@ -152,8 +154,11 @@ template <class T, class... Ts>
auto hip_visit_views(T&& x, Ts&&... xs)
{
return [&](auto f) {
hip_visit_views_impl(
get_shape(x), make_hip_convert_view([](auto v) { return device_cast(v); }), f, x, xs...);
hip_visit_views_impl(get_shape(x),
make_hip_convert_view([](auto v) { return device_cast(v); }),
f,
x,
xs...);
};
}
......
......@@ -12,9 +12,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs);
}
argument hip_gather::compute(context& ctx,
const shape&,
const std::vector<argument>& args) const
argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment