Commit c1ec929c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents abe2a889 03225b57
lpnormalization_l1_test:f
!
xy"LpNormalization*
plpnormalization_l1_testZ
x


b
y


B
\ No newline at end of file
lpnormalization_l2_test:f
!
xy"LpNormalization*
plpnormalization_l2_testZ
x


b
y


B
\ No newline at end of file
lpnormalization_p_error_test:k
!
xy"LpNormalization*
plpnormalization_p_error_testZ
x


b
y


B
\ No newline at end of file
This diff is collapsed.
size_float_test:I
xy"Sizesize_float_testZ
x



b
y

B
\ No newline at end of file
size_half_test:D
xy"Sizesize_half_testZ
x



b
y

B
\ No newline at end of file
 size_int_test:G
xy"Size size_int_testZ
x



b
y

B
\ No newline at end of file
size_verify_test:J
xy"Sizesize_verify_testZ
x



b
y

B
\ No newline at end of file
...@@ -45,6 +45,44 @@ TEST_CASE(averagepool_nt_cip_test) ...@@ -45,6 +45,44 @@ TEST_CASE(averagepool_nt_cip_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(celu_verify_test)
{
migraphx::program p = migraphx::parse_onnx("celu_verify_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {-5.5, 2.0, 100., 7.0, 0., -1.};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct(6);
float alpha = 0.5;
std::transform(data.begin(), data.end(), correct.begin(), [&](auto x) {
return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha));
});
EXPECT(migraphx::verify_range(result_vector, correct));
}
TEST_CASE(clip_args_type_mismatch)
{
auto p = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_0{migraphx::shape::float_type, {3, 3}};
migraphx::parameter_map pp;
std::vector<float> data_0 = {0.9, 1.2, 1.7, 1.9, 2.2, 2.7, 2.9, 3.2, 3.7};
pp["0"] = migraphx::argument(s_0, data_0.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(depthtospace_simple_test) TEST_CASE(depthtospace_simple_test)
{ {
auto p = migraphx::parse_onnx("depthtospace_simple_test.onnx"); auto p = migraphx::parse_onnx("depthtospace_simple_test.onnx");
...@@ -103,6 +141,42 @@ TEST_CASE(spacetodepth_depthtospace_test) ...@@ -103,6 +141,42 @@ TEST_CASE(spacetodepth_depthtospace_test)
EXPECT(migraphx::verify_range(result_vector2, data_in)); EXPECT(migraphx::verify_range(result_vector2, data_in));
} }
TEST_CASE(eyelike_verify_test)
{
migraphx::program p = migraphx::parse_onnx("eyelike_verify_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {3, 4}};
std::vector<float> data{12, 0};
migraphx::parameter_map pp;
pp["T1"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.};
EXPECT(migraphx::verify_range(result_vector, eyelike_mat));
}
TEST_CASE(eyelike_verify_negk_test)
{
migraphx::program p = migraphx::parse_onnx("eyelike_verify_negk_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {3, 4}};
std::vector<float> data{12, 0};
migraphx::parameter_map pp;
pp["T1"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.};
EXPECT(migraphx::verify_range(result_vector, eyelike_mat));
}
TEST_CASE(gather_elements) TEST_CASE(gather_elements)
{ {
migraphx::program p = migraphx::parse_onnx("gather_elements_axis0_test.onnx"); migraphx::program p = migraphx::parse_onnx("gather_elements_axis0_test.onnx");
...@@ -393,6 +467,62 @@ TEST_CASE(lessorequal_test) ...@@ -393,6 +467,62 @@ TEST_CASE(lessorequal_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(lpnormalization_1norm)
{
migraphx::program p = migraphx::parse_onnx("lpnormalization_l1_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {3, 4}};
std::vector<float> data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.f,
2.f / 5.f,
-2.f / 5.f,
1.f / 5.f,
1.f / 10.f,
-5.f / 10.f,
3.f / 10.f,
-1.f / 10.f,
-4.f / 7.f,
3.f / 7.f,
0.f,
0.f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(lpnormalization_2norm)
{
migraphx::program p = migraphx::parse_onnx("lpnormalization_l2_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {3, 4}};
std::vector<float> data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct{0.f,
2.f / 3.f,
-2.f / 3.f,
1.f / 3.f,
1.f / 6.f,
-5.f / 6.f,
3.f / 6.f,
-1.f / 6.f,
-4.f / 5.f,
3.f / 5.f,
0.f,
0.f};
EXPECT(migraphx::verify_range(result_vector, correct));
}
TEST_CASE(mean_broadcast_test) TEST_CASE(mean_broadcast_test)
{ {
migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx"); migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx");
...@@ -588,6 +718,21 @@ TEST_CASE(selu_test) ...@@ -588,6 +718,21 @@ TEST_CASE(selu_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(size_verify_test)
{
migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {2, 5, 3}};
std::vector<float> data(30, 1.);
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
auto size_result = result.at<int64_t>();
EXPECT(size_result == int64_t{30});
}
TEST_CASE(slice_test) TEST_CASE(slice_test)
{ {
migraphx::program p = migraphx::parse_onnx("slice_test.onnx"); migraphx::program p = migraphx::parse_onnx("slice_test.onnx");
......
...@@ -570,10 +570,12 @@ TEST_CASE(inconsistent_attr_shape) ...@@ -570,10 +570,12 @@ TEST_CASE(inconsistent_attr_shape)
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}), {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input, input,
weights); weights);
throws_shape( throws_shape(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}), {"padding", {1}},
input); {"stride", {0}},
{"lengths", {1, 1}}}),
input);
} }
template <class T> template <class T>
...@@ -983,21 +985,24 @@ TEST_CASE(pooling_shape) ...@@ -983,21 +985,24 @@ TEST_CASE(pooling_shape)
{ {
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape( throws_shape(migraphx::make_op("pooling",
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max},
{{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}), {"padding", {1}},
input); {"stride", {0}},
expect_shape( {"lengths", {1}}}),
output, input);
migraphx::make_op( expect_shape(output,
"pooling", migraphx::make_op("pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}), {{"mode", migraphx::op::pooling_mode::max},
input); {"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
input);
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}}; migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output1, expect_shape(output1,
migraphx::make_op("pooling", migraphx::make_op("pooling",
{{"mode", "max"}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}}, {"padding", {0, 0}},
{"stride", {3, 3}}, {"stride", {3, 3}},
{"lengths", {1, 1}}, {"lengths", {1, 1}},
......
...@@ -25,10 +25,11 @@ endforeach() ...@@ -25,10 +25,11 @@ endforeach()
add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(backend onnx_backend_test.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(backend onnx_backend_test.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif() endif()
...@@ -96,6 +96,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -96,6 +96,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_AvgPool.*') backend_test.include(r'.*test_AvgPool.*')
backend_test.include(r'.*test_BatchNorm.*eval.*') backend_test.include(r'.*test_BatchNorm.*eval.*')
backend_test.include(r'.*test_ceil.*') backend_test.include(r'.*test_ceil.*')
backend_test.include(r'.*test_celu.*')
backend_test.include(r'.*test_clip.*') backend_test.include(r'.*test_clip.*')
backend_test.include(r'.*test_concat.*') backend_test.include(r'.*test_concat.*')
backend_test.include(r'.*test_constant.*') backend_test.include(r'.*test_constant.*')
...@@ -111,6 +112,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -111,6 +112,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_equal.*') backend_test.include(r'.*test_equal.*')
backend_test.include(r'.*test_Embedding*') backend_test.include(r'.*test_Embedding*')
backend_test.include(r'.*test_exp.*') backend_test.include(r'.*test_exp.*')
backend_test.include(r'.*test_eyelike.*')
backend_test.include(r'.*test_flatten.*') backend_test.include(r'.*test_flatten.*')
backend_test.include(r'.*test_floor.*') backend_test.include(r'.*test_floor.*')
backend_test.include(r'.*test_gather.*') backend_test.include(r'.*test_gather.*')
...@@ -273,8 +275,6 @@ def create_backend_test(testname=None, target_device=None): ...@@ -273,8 +275,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_negative_log_likelihood_loss_*') backend_test.exclude(r'test_negative_log_likelihood_loss_*')
# all reduce ops have dynamic axes inputs # all reduce ops have dynamic axes inputs
backend_test.exclude(r'test_size_cpu')
backend_test.exclude(r'test_size_example_cpu')
backend_test.exclude(r'test_softmax_cross_entropy_*') backend_test.exclude(r'test_softmax_cross_entropy_*')
backend_test.exclude(r'test_Embedding_cpu') backend_test.exclude(r'test_Embedding_cpu')
......
import migraphx
def test_create_shape():
s = migraphx.shape(lens=[1, 64, 3, 3])
assert s.standard()
assert s.packed()
assert s.lens() == [1, 64, 3, 3]
def test_create_shape_broadcast():
s = migraphx.shape(lens=[1, 64, 3, 3], strides=[0, 1, 0, 0])
assert s.broadcasted()
assert s.lens() == [1, 64, 3, 3]
assert s.strides() == [0, 1, 0, 0]
def test_create_shape_type():
s = migraphx.shape(type='uint8')
assert s.type_string() == 'uint8_type'
assert s.type_size() == 1
...@@ -370,7 +370,7 @@ TEST_CASE(avgpool_test) ...@@ -370,7 +370,7 @@ TEST_CASE(avgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {1}; op.stride = {1};
...@@ -392,7 +392,7 @@ TEST_CASE(avgpool_test) ...@@ -392,7 +392,7 @@ TEST_CASE(avgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2}; op.lengths = {2};
op.padding = {1}; op.padding = {1};
op.stride = {2}; op.stride = {2};
...@@ -439,7 +439,7 @@ TEST_CASE(avgpool_test) ...@@ -439,7 +439,7 @@ TEST_CASE(avgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 2, 2}; op.lengths = {2, 2, 2};
op.padding = {0, 0, 0}; op.padding = {0, 0, 0};
op.stride = {1, 1, 1}; op.stride = {1, 1, 1};
...@@ -1658,7 +1658,7 @@ TEST_CASE(globalavgpool_test) ...@@ -1658,7 +1658,7 @@ TEST_CASE(globalavgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = s.lens(); auto lens = s.lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
...@@ -1679,7 +1679,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -1679,7 +1679,7 @@ TEST_CASE(globalmaxpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = s.lens(); auto lens = s.lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
...@@ -2582,11 +2582,12 @@ TEST_CASE(maxpool_test) ...@@ -2582,11 +2582,12 @@ TEST_CASE(maxpool_test)
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802}; 0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0}},
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {3, 2}}}), {"stride", {2, 2}},
al); {"lengths", {3, 2}}}),
al);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(36); std::vector<float> results_vector(36);
...@@ -2601,7 +2602,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2601,7 +2602,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {1}; op.stride = {1};
...@@ -2623,7 +2624,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2623,7 +2624,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {2}; op.stride = {2};
...@@ -2647,7 +2648,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2647,7 +2648,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {2}; op.stride = {2};
...@@ -2683,7 +2684,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2683,7 +2684,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2}; op.lengths = {2, 2, 2};
op.padding = {0, 0, 0}; op.padding = {0, 0, 0};
op.stride = {2, 2, 2}; op.stride = {2, 2, 2};
...@@ -4037,9 +4038,10 @@ TEST_CASE(roialign_out_of_bound_test) ...@@ -4037,9 +4038,10 @@ TEST_CASE(roialign_out_of_bound_test)
TEST_CASE(roialign_test) TEST_CASE(roialign_test)
{ {
auto create_program = [](const std::string& trans_mode = "half_pixel", auto create_program = [](const std::string& trans_mode = "half_pixel",
const std::string& pooling_mode = "avg", const migraphx::op::pooling_mode pooling_mode =
int64_t sampling_ratio = 2) { migraphx::op::pooling_mode::average,
int64_t sampling_ratio = 2) {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}}; migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}};
...@@ -4125,7 +4127,7 @@ TEST_CASE(roialign_test) ...@@ -4125,7 +4127,7 @@ TEST_CASE(roialign_test)
} }
{ {
auto p = create_program("output_half_pixel", "max", 0); auto p = create_program("output_half_pixel", migraphx::op::pooling_mode::max, 0);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
......
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
...@@ -22,7 +23,7 @@ static void opt_pooling(migraphx::module& m) ...@@ -22,7 +23,7 @@ static void opt_pooling(migraphx::module& m)
TEST_CASE(rewrite_pooling_test) TEST_CASE(rewrite_pooling_test)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&](const std::string& mode) { auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret = m.add_instruction(migraphx::make_op("pooling",
...@@ -46,15 +47,16 @@ TEST_CASE(rewrite_pooling_test) ...@@ -46,15 +47,16 @@ TEST_CASE(rewrite_pooling_test)
return m; return m;
}; };
auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) { auto test_rewrite = [&](const migraphx::op::pooling_mode mode, const migraphx::operation& op) {
migraphx::module m1 = pooling_program(mode); migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(op); migraphx::module m2 = opt_program(op);
opt_pooling(m1); opt_pooling(m1);
EXPECT(m1 == m2); EXPECT(m1 == m2);
}; };
test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}})); test_rewrite(migraphx::op::pooling_mode::average,
test_rewrite("max", migraphx::make_op("reduce_max", {{"axes", {1}}})); migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite(migraphx::op::pooling_mode::max, migraphx::make_op("reduce_max", {{"axes", {1}}}));
} }
TEST_CASE(rewrite_avepooling_na1_test) TEST_CASE(rewrite_avepooling_na1_test)
...@@ -64,12 +66,13 @@ TEST_CASE(rewrite_avepooling_na1_test) ...@@ -64,12 +66,13 @@ TEST_CASE(rewrite_avepooling_na1_test)
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret =
{{"mode", "average"}, m.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 1, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1, 1}}, {"padding", {0, 1, 0}},
{"lengths", {3, 4, 5}}}), {"stride", {1, 1, 1}},
input); {"lengths", {3, 4, 5}}}),
input);
m.add_return({ret}); m.add_return({ret});
return m; return m;
}; };
...@@ -88,12 +91,13 @@ TEST_CASE(rewrite_avepooling_na2_test) ...@@ -88,12 +91,13 @@ TEST_CASE(rewrite_avepooling_na2_test)
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret =
{{"mode", "average"}, m.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 2, 1}}, {"padding", {0, 0, 0}},
{"lengths", {3, 4, 5}}}), {"stride", {1, 2, 1}},
input); {"lengths", {3, 4, 5}}}),
input);
m.add_return({ret}); m.add_return({ret});
return m; return m;
}; };
...@@ -113,7 +117,7 @@ TEST_CASE(rewrite_avepooling_na3_test) ...@@ -113,7 +117,7 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", "max"}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}), {"lengths", {3, 3, 5}}}),
...@@ -135,7 +139,7 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -135,7 +139,7 @@ TEST_CASE(literal_rewrite_pooling_test)
std::vector<float> data(s.elements()); std::vector<float> data(s.elements());
std::iota(data.begin(), data.end(), 1.0f); std::iota(data.begin(), data.end(), 1.0f);
auto pooling_program = [&](const std::string& mode) { auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -163,7 +167,8 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -163,7 +167,8 @@ TEST_CASE(literal_rewrite_pooling_test)
return p; return p;
}; };
auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) { auto test_rewrite_pooling = [&](const migraphx::op::pooling_mode mode,
const migraphx::operation& op) {
migraphx::program p1 = pooling_program(mode); migraphx::program p1 = pooling_program(mode);
migraphx::program p2 = opt_program(op); migraphx::program p2 = opt_program(op);
p1.compile(migraphx::ref::target{}); p1.compile(migraphx::ref::target{});
...@@ -174,8 +179,10 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -174,8 +179,10 @@ TEST_CASE(literal_rewrite_pooling_test)
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}; };
test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}})); test_rewrite_pooling(migraphx::op::pooling_mode::max,
test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}})); migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite_pooling(migraphx::op::pooling_mode::average,
migraphx::make_op("reduce_mean", {{"axes", {1}}}));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
...@@ -462,14 +463,15 @@ TEST_CASE(conv_pooling_dot) ...@@ -462,14 +463,15 @@ TEST_CASE(conv_pooling_dot)
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling", auto ap =
{{"mode", "average"}, m1.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0, 0, 0}},
{"lengths", {7, 7}}, {"stride", {1, 1}},
{"ceil_mode", 0}}), {"lengths", {7, 7}},
a1); {"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
...@@ -508,14 +510,15 @@ TEST_CASE(conv_pooling_dot) ...@@ -508,14 +510,15 @@ TEST_CASE(conv_pooling_dot)
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling", auto ap =
{{"mode", "average"}, m2.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0, 0, 0}},
{"lengths", {7, 7}}, {"stride", {1, 1}},
{"ceil_mode", 0}}), {"lengths", {7, 7}},
a1); {"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
...@@ -564,16 +567,17 @@ TEST_CASE(mobilenet_snippet) ...@@ -564,16 +567,17 @@ TEST_CASE(mobilenet_snippet)
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling", auto ap =
{{"mode", "average"}, mm.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0, 0, 0}},
{"lengths", {7, 7}}, {"stride", {1, 1}},
{"ceil_mode", 0}}), {"lengths", {7, 7}},
d6); {"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero); auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero); auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7); auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
......
...@@ -649,8 +649,8 @@ TEST_CASE(pooling_test) ...@@ -649,8 +649,8 @@ TEST_CASE(pooling_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling avg_pool_op{"average"}; migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average};
migraphx::op::pooling max_pool_op{"max"}; migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max};
avg_pool_op.stride = {2, 2}; avg_pool_op.stride = {2, 2};
max_pool_op.stride = {2, 2}; max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
......
...@@ -12,7 +12,7 @@ struct test_avg_pooling_1d : verify_program<test_avg_pooling_1d> ...@@ -12,7 +12,7 @@ struct test_avg_pooling_1d : verify_program<test_avg_pooling_1d>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto op = migraphx::op::pooling{"average", {0}, {1}, {3}}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average, {0}, {1}, {3}};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d> ...@@ -12,7 +12,8 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}}; auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment