Commit 27c0ae08 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent dd651742
......@@ -51,9 +51,9 @@ struct scatter
args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape();
shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
index = (index < 0) ? index + axis_dim_size : index;
auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
});
......
......@@ -13,8 +13,8 @@ namespace device {
argument scatter(
hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{
auto ds = arg0.get_shape();
auto inds = arg1.get_shape();
auto ds = arg0.get_shape();
auto inds = arg1.get_shape();
auto axis_dim_size = ds.lens()[axis];
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
auto* output_ptr = device_cast(output.data());
......@@ -26,8 +26,8 @@ argument scatter(
const auto* indices_ptr = device_cast(indices.data());
gs_launch(stream, inds.elements(), 256)([=](auto i) __device__ {
auto out_idx = s1.multi(i);
auto index = indices_ptr[i];
index = index < 0 ? index + axis_dim_size : index;
auto index = indices_ptr[i];
index = index < 0 ? index + axis_dim_size : index;
out_idx[axis] = index;
output[out_idx] = upd_ptr[i];
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment