Commit dd651742 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine unit tests

parent c07b91c8
......@@ -45,13 +45,16 @@ struct scatter
{
argument result{output_shape};
// max dimension in axis
auto axis_dim_size = output_shape.lens()[axis];
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape();
shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx;
out_idx[axis] = indices[ind_s.index(idx)];
auto index = indices[ind_s.index(idx)];
index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
});
});
......
......@@ -36,6 +36,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Relu", "relu"},
{"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
......
......@@ -15,6 +15,7 @@ argument scatter(
{
auto ds = arg0.get_shape();
auto inds = arg1.get_shape();
auto axis_dim_size = ds.lens()[axis];
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data());
......@@ -25,7 +26,9 @@ argument scatter(
const auto* indices_ptr = device_cast(indices.data());
gs_launch(stream, inds.elements(), 256)([=](auto i) __device__ {
auto out_idx = s1.multi(i);
out_idx[axis] = indices_ptr[i];
auto index = indices_ptr[i];
index = index < 0 ? index + axis_dim_size : index;
out_idx[axis] = index;
output[out_idx] = upd_ptr[i];
});
});
......
......@@ -172,6 +172,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_reduce.*')
backend_test.include(r'.*test_ReLU*')
backend_test.include(r'.*test_relu.*')
backend_test.include(r'.*test_scatter.*')
backend_test.include(r'.*test_Scatter.*')
backend_test.include(r'.*test_selu.*')
backend_test.include(r'.*test_shape.*')
backend_test.include(r'.*test_Sigmoid*')
......@@ -272,6 +274,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_mean_one_input_cpu')
backend_test.exclude(r'test_mean_two_inputs_cpu')
backend_test.exclude(r'test_negative_log_likelihood_loss_*')
backend_test.exclude(r'test_scatternd_*')
# all reduce ops have dynamic axes inputs
backend_test.exclude(r'test_size_cpu')
......
......@@ -3685,7 +3685,7 @@ TEST_CASE(scatter_test)
std::vector<float> vd(sd.elements(), 0.0f);
migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
std::vector<int> vi = {1, 0, -1, 0, 2, -2};
migraphx::shape su{migraphx::shape::float_type, {2, 3}};
std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2};
......
......@@ -13,7 +13,7 @@ struct test_scatter1 : verify_program<test_scatter1>
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
std::vector<int> vi = {-2, 0, 2, 0, -1, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}};
auto pd = mm->add_parameter("data", sd);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment