Commit dd651742 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine unit tests

parent c07b91c8
...@@ -45,13 +45,16 @@ struct scatter ...@@ -45,13 +45,16 @@ struct scatter
{ {
argument result{output_shape}; argument result{output_shape};
// max dimension in axis // max dimension in axis
auto axis_dim_size = output_shape.lens()[axis];
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) { visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
std::copy(data.begin(), data.end(), output.begin()); std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape(); auto ind_s = indices.get_shape();
shape_for_each(ind_s, [&](const auto& idx) { shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx; auto out_idx = idx;
out_idx[axis] = indices[ind_s.index(idx)]; auto index = indices[ind_s.index(idx)];
index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)]; output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
}); });
}); });
......
...@@ -36,6 +36,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -36,6 +36,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "round"},
{"Scatter", "scatter"}, {"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
......
...@@ -15,6 +15,7 @@ argument scatter( ...@@ -15,6 +15,7 @@ argument scatter(
{ {
auto ds = arg0.get_shape(); auto ds = arg0.get_shape();
auto inds = arg1.get_shape(); auto inds = arg1.get_shape();
auto axis_dim_size = ds.lens()[axis];
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) { hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data()); const auto* data_ptr = device_cast(data.data());
...@@ -25,7 +26,9 @@ argument scatter( ...@@ -25,7 +26,9 @@ argument scatter(
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
gs_launch(stream, inds.elements(), 256)([=](auto i) __device__ { gs_launch(stream, inds.elements(), 256)([=](auto i) __device__ {
auto out_idx = s1.multi(i); auto out_idx = s1.multi(i);
out_idx[axis] = indices_ptr[i]; auto index = indices_ptr[i];
index = index < 0 ? index + axis_dim_size : index;
out_idx[axis] = index;
output[out_idx] = upd_ptr[i]; output[out_idx] = upd_ptr[i];
}); });
}); });
......
...@@ -172,6 +172,8 @@ def create_backend_test(testname=None, target_device=None): ...@@ -172,6 +172,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_reduce.*') backend_test.include(r'.*test_reduce.*')
backend_test.include(r'.*test_ReLU*') backend_test.include(r'.*test_ReLU*')
backend_test.include(r'.*test_relu.*') backend_test.include(r'.*test_relu.*')
backend_test.include(r'.*test_scatter.*')
backend_test.include(r'.*test_Scatter.*')
backend_test.include(r'.*test_selu.*') backend_test.include(r'.*test_selu.*')
backend_test.include(r'.*test_shape.*') backend_test.include(r'.*test_shape.*')
backend_test.include(r'.*test_Sigmoid*') backend_test.include(r'.*test_Sigmoid*')
...@@ -272,6 +274,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -272,6 +274,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_mean_one_input_cpu') backend_test.exclude(r'test_mean_one_input_cpu')
backend_test.exclude(r'test_mean_two_inputs_cpu') backend_test.exclude(r'test_mean_two_inputs_cpu')
backend_test.exclude(r'test_negative_log_likelihood_loss_*') backend_test.exclude(r'test_negative_log_likelihood_loss_*')
backend_test.exclude(r'test_scatternd_*')
# all reduce ops have dynamic axes inputs # all reduce ops have dynamic axes inputs
backend_test.exclude(r'test_size_cpu') backend_test.exclude(r'test_size_cpu')
......
...@@ -3685,7 +3685,7 @@ TEST_CASE(scatter_test) ...@@ -3685,7 +3685,7 @@ TEST_CASE(scatter_test)
std::vector<float> vd(sd.elements(), 0.0f); std::vector<float> vd(sd.elements(), 0.0f);
migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1}; std::vector<int> vi = {1, 0, -1, 0, 2, -2};
migraphx::shape su{migraphx::shape::float_type, {2, 3}}; migraphx::shape su{migraphx::shape::float_type, {2, 3}};
std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2}; std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2};
......
...@@ -13,7 +13,7 @@ struct test_scatter1 : verify_program<test_scatter1> ...@@ -13,7 +13,7 @@ struct test_scatter1 : verify_program<test_scatter1>
migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1}; std::vector<int> vi = {-2, 0, 2, 0, -1, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}}; migraphx::shape su{migraphx::shape::float_type, {2, 3}};
auto pd = mm->add_parameter("data", sd); auto pd = mm->add_parameter("data", sd);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment