Commit fae5170b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_op_name

parents a3906038 2a79a9ff
!randomuniform_generated_seed_test:
2
inputoutput" RandomUniform*
sample_size
!randomuniform_generated_seed_testZ
input



b
output



B
\ No newline at end of file
softplus_nd_test:V

xy"Softplussoftplus_nd_testZ
x




b
y




B
\ No newline at end of file
 softplus_test:C

xy"Softplus softplus_testZ
x

b
y

B
\ No newline at end of file
softsign_nd_test:V

xy"Softsignsoftsign_nd_testZ
x




b
y




B
\ No newline at end of file
 softsign_test:C

xy"Softsign softsign_testZ
x

b
y

B
\ No newline at end of file
...@@ -126,6 +126,51 @@ TEST_CASE(gather_elements) ...@@ -126,6 +126,51 @@ TEST_CASE(gather_elements)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(greaterorequal_test)
{
migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {3}};
std::vector<float> data1 = {0.25, 0.75, 0.9375};
std::vector<float> data2 = {0.25, 0.74, 0.9411};
migraphx::parameter_map pp;
pp["x1"] = migraphx::argument(s, data1.data());
pp["x2"] = migraphx::argument(s, data2.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, 0.0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(hardsigmoid_verify_test)
{
migraphx::program p = migraphx::parse_onnx("hardsigmoid_verify_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {2, 5}};
std::vector<float> data = {-10.0, -2.5, -1.0, -0.5, 0, 1.0, 2.0, 2.5, 2.6, 100.0};
float alpha = 0.2;
float beta = 0.5;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(10);
std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, std::min(x * alpha + beta, 1.0f));
});
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_else_test) TEST_CASE(if_else_test)
{ {
migraphx::program p = migraphx::parse_onnx("if_else_test.onnx"); migraphx::program p = migraphx::parse_onnx("if_else_test.onnx");
...@@ -348,6 +393,64 @@ TEST_CASE(lessorequal_test) ...@@ -348,6 +393,64 @@ TEST_CASE(lessorequal_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(mean_broadcast_test)
{
migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s0{migraphx::shape::float_type, {1, 3, 4}};
std::vector<float> data0(12, 1);
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 3, 4}};
std::vector<float> data1(24, 2);
migraphx::shape s2{migraphx::shape::float_type, {4}};
std::vector<float> data2(4, 3);
migraphx::shape s3{migraphx::shape::float_type, {1}};
std::vector<float> data3(1, 4);
migraphx::shape s4{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> data4(6, 5);
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(s0, data0.data());
pp["1"] = migraphx::argument(s1, data1.data());
pp["2"] = migraphx::argument(s2, data2.data());
pp["3"] = migraphx::argument(s3, data3.data());
pp["4"] = migraphx::argument(s4, data4.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(24, 3);
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(mean_test)
{
migraphx::program p = migraphx::parse_onnx("mean_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::double_type, {2, 2, 2}};
const int num_elms = 8;
const int num_data = 10;
const std::vector<double> scalars{1.0, 2.0, -2.5, 3.3, 10.7, -1.0, 100.0, 7.9, 0.01, -56.8};
std::vector<std::vector<double>> data;
std::transform(scalars.begin(), scalars.end(), std::back_inserter(data), [&](const auto& i) {
return std::vector<double>(num_elms, i);
});
migraphx::parameter_map pp;
for(std::size_t i = 0; i < num_data; ++i)
pp[std::to_string(i)] = migraphx::argument(s, data[i].data());
auto result = p.eval(pp).back();
std::vector<double> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data;
std::vector<double> gold(num_elms, mean);
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(nonzero_test) TEST_CASE(nonzero_test)
{ {
migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx"); migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx");
...@@ -564,6 +667,48 @@ TEST_CASE(slice_step_test) ...@@ -564,6 +667,48 @@ TEST_CASE(slice_step_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(softplus_test)
{
migraphx::program p = migraphx::parse_onnx("softplus_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {5}};
std::vector<float> data = {0, 1, 2, 3, 4};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(5);
std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(softsign_test)
{
migraphx::program p = migraphx::parse_onnx("softsign_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::float_type, {5}};
std::vector<float> data = {0, 1, 2, 3, 4};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(5);
std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(upsample_test) TEST_CASE(upsample_test)
{ {
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx"); migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
......
...@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_globalmaxpool.*') backend_test.include(r'.*test_globalmaxpool.*')
backend_test.include(r'.*test_greater.*') backend_test.include(r'.*test_greater.*')
backend_test.include(r'.*test_hardsigmoid.*') backend_test.include(r'.*test_hardsigmoid.*')
backend_test.include(r'.*test_hardswish.*')
backend_test.include(r'.*test_identity.*') backend_test.include(r'.*test_identity.*')
backend_test.include(r'.*test_if.*') backend_test.include(r'.*test_if.*')
backend_test.include(r'.*test_LeakyReLU*') backend_test.include(r'.*test_LeakyReLU*')
...@@ -266,18 +267,8 @@ def create_backend_test(testname=None, target_device=None): ...@@ -266,18 +267,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_gathernd_example_float32_cpu') backend_test.exclude(r'test_gathernd_example_float32_cpu')
backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu') backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu')
backend_test.exclude(r'test_gathernd_example_int32_cpu') backend_test.exclude(r'test_gathernd_example_int32_cpu')
backend_test.exclude(r'test_greater_equal_bcast_cpu')
backend_test.exclude(r'test_greater_equal_bcast_expanded_cpu')
backend_test.exclude(r'test_greater_equal_cpu')
backend_test.exclude(r'test_greater_equal_expanded_cpu')
backend_test.exclude(r'test_hardsigmoid_cpu')
backend_test.exclude(r'test_hardsigmoid_default_cpu')
backend_test.exclude(r'test_hardsigmoid_example_cpu')
backend_test.exclude(r'test_identity_sequence_cpu') backend_test.exclude(r'test_identity_sequence_cpu')
backend_test.exclude(r'test_maxpool_2d_uint8_cpu') backend_test.exclude(r'test_maxpool_2d_uint8_cpu')
backend_test.exclude(r'test_mean_example_cpu')
backend_test.exclude(r'test_mean_one_input_cpu')
backend_test.exclude(r'test_mean_two_inputs_cpu')
backend_test.exclude(r'test_negative_log_likelihood_loss_*') backend_test.exclude(r'test_negative_log_likelihood_loss_*')
backend_test.exclude(r'test_scatternd_*') backend_test.exclude(r'test_scatternd_*')
...@@ -285,12 +276,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -285,12 +276,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_size_cpu') backend_test.exclude(r'test_size_cpu')
backend_test.exclude(r'test_size_example_cpu') backend_test.exclude(r'test_size_example_cpu')
backend_test.exclude(r'test_softmax_cross_entropy_*') backend_test.exclude(r'test_softmax_cross_entropy_*')
backend_test.exclude(r'test_softplus_cpu')
backend_test.exclude(r'test_softplus_example_cpu')
backend_test.exclude(r'test_softsign_cpu')
backend_test.exclude(r'test_softsign_example_cpu')
backend_test.exclude(r'test_Embedding_cpu') backend_test.exclude(r'test_Embedding_cpu')
backend_test.exclude(r'test_Softplus_cpu')
# real model tests # real model tests
backend_test.exclude(r'test_inception_v1_cpu') backend_test.exclude(r'test_inception_v1_cpu')
......
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp>
#include "test.hpp"
TEST_CASE(argmax_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(argmin_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1473,6 +1473,32 @@ TEST_CASE(fp32_fp16_test) ...@@ -1473,6 +1473,32 @@ TEST_CASE(fp32_fp16_test)
test_case({"add"}); test_case({"add"});
} }
TEST_CASE(gather_non_std_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {0.5f, 3.5f, 6.5f, 1.5f, 4.5f, 7.5f, 2.5f, 2.5f, 8.5f};
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto d = mm->add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{-3, -3, -1, -1};
auto ind = mm->add_literal(migraphx::literal{s_indices, indices});
auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d);
auto tind =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), td, tind);
auto result = p.eval({}).back();
std::vector<float> golden = {
0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f, 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
std::vector<float> res_data;
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
}
TEST_CASE(gather_test) TEST_CASE(gather_test)
{ {
{ {
...@@ -2784,7 +2810,6 @@ TEST_CASE(nms_not_center_test) ...@@ -2784,7 +2810,6 @@ TEST_CASE(nms_not_center_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::cout << "output = " << output << std::endl;
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
...@@ -2818,7 +2843,6 @@ TEST_CASE(nms_test) ...@@ -2818,7 +2843,6 @@ TEST_CASE(nms_test)
auto output = p.eval({}).back(); auto output = p.eval({}).back();
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::cout << "output = " << output << std::endl;
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <set> #include <set>
...@@ -15,6 +16,7 @@ ...@@ -15,6 +16,7 @@
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
// An improved async, that doesn't block // An improved async, that doesn't block
template <class Function> template <class Function>
...@@ -125,6 +127,8 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -125,6 +127,8 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
using result_future = using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>; std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
auto_print::set_terminate_handler(name); auto_print::set_terminate_handler(name);
if(migraphx::enabled(MIGRAPHX_DUMP_TEST{}))
migraphx::save(p, name + ".mx");
std::vector<std::pair<std::string, result_future>> results; std::vector<std::pair<std::string, result_future>> results;
std::vector<std::string> target_names; std::vector<std::string> target_names;
for(const auto& tname : migraphx::get_targets()) for(const auto& tname : migraphx::get_targets())
......
...@@ -2,34 +2,92 @@ ...@@ -2,34 +2,92 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
template <class T, int Axis> template <class T, int Axis, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}}; migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}};
auto param = mm->add_parameter("data", s); auto param = mm->add_parameter("data", s);
switch(NonStdShape)
{
case 0:
param = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), param);
break;
case 1:
param = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 1025}}}), param);
break;
case 2:
param = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), param);
break;
default: break;
}
mm->add_instruction(T{Axis}, param); mm->add_instruction(T{Axis}, param);
return p; return p;
} }
}; };
// transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3>; template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1>; template struct test_arg_ops<migraphx::op::argmax, -1, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2>; template struct test_arg_ops<migraphx::op::argmax, -2, 0>;
// transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3>; template struct test_arg_ops<migraphx::op::argmin, 3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3>; template struct test_arg_ops<migraphx::op::argmin, -3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4>; template struct test_arg_ops<migraphx::op::argmin, -4, 0>;
// broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>;
// broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, 1>;
// slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>;
// slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, 2>;
// default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>;
// default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, 3>;
...@@ -28,9 +28,9 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> ...@@ -28,9 +28,9 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val); migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val); migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), bias_add, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), bias_add, min_val, max_val);
return p; return p;
} }
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_hsqrt : verify_program<test_hsqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_nonstd_gather : verify_program<test_nonstd_gather>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 1, 0, 2};
auto d = mm->add_parameter("data", s);
auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d);
auto ind = mm->add_literal(migraphx::literal{s_indices, indices});
auto tind =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), td, tind);
mm->add_return({r});
return p;
}
};
import string, sys, re, runpy import string, sys, re, runpy
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
type_map = {} type_map: Dict[str, Callable[['Parameter'], None]] = {}
cpp_type_map = {} cpp_type_map: Dict[str, str] = {}
functions = [] functions: List['Function'] = []
cpp_classes = [] cpp_classes: List['CPPClass'] = []
error_type = '' error_type = ''
success_type = '' success_type = ''
try_wrap = '' try_wrap = ''
c_header_preamble = [] c_header_preamble: List[str] = []
c_api_body_preamble = [] c_api_body_preamble: List[str] = []
cpp_header_preamble = [] cpp_header_preamble: List[str] = []
def bad_param_error(msg): def bad_param_error(msg):
...@@ -23,31 +24,31 @@ class Template(string.Template): ...@@ -23,31 +24,31 @@ class Template(string.Template):
class Type: class Type:
def __init__(self, name): def __init__(self, name: str) -> None:
self.name = name.strip() self.name = name.strip()
def is_pointer(self): def is_pointer(self) -> bool:
return self.name.endswith('*') return self.name.endswith('*')
def is_reference(self): def is_reference(self) -> bool:
return self.name.endswith('&') return self.name.endswith('&')
def is_const(self): def is_const(self) -> bool:
return self.name.startswith('const ') return self.name.startswith('const ')
def is_variadic(self): def is_variadic(self):
return self.name.startswith('...') return self.name.startswith('...')
def add_pointer(self): def add_pointer(self) -> 'Type':
return Type(self.name + '*') return Type(self.name + '*')
def add_reference(self): def add_reference(self):
return Type(self.name + '&') return Type(self.name + '&')
def add_const(self): def add_const(self) -> 'Type':
return Type('const ' + self.name) return Type('const ' + self.name)
def inner_type(self): def inner_type(self) -> Optional['Type']:
i = self.name.find('<') i = self.name.find('<')
j = self.name.rfind('>') j = self.name.rfind('>')
if i > 0 and j > 0: if i > 0 and j > 0:
...@@ -55,7 +56,7 @@ class Type: ...@@ -55,7 +56,7 @@ class Type:
else: else:
return None return None
def remove_generic(self): def remove_generic(self) -> 'Type':
i = self.name.find('<') i = self.name.find('<')
j = self.name.rfind('>') j = self.name.rfind('>')
if i > 0 and j > 0: if i > 0 and j > 0:
...@@ -63,25 +64,25 @@ class Type: ...@@ -63,25 +64,25 @@ class Type:
else: else:
return self return self
def remove_pointer(self): def remove_pointer(self) -> 'Type':
if self.is_pointer(): if self.is_pointer():
return Type(self.name[0:-1]) return Type(self.name[0:-1])
return self return self
def remove_reference(self): def remove_reference(self) -> 'Type':
if self.is_reference(): if self.is_reference():
return Type(self.name[0:-1]) return Type(self.name[0:-1])
return self return self
def remove_const(self): def remove_const(self) -> 'Type':
if self.is_const(): if self.is_const():
return Type(self.name[6:]) return Type(self.name[6:])
return self return self
def basic(self): def basic(self) -> 'Type':
return self.remove_pointer().remove_const().remove_reference() return self.remove_pointer().remove_const().remove_reference()
def decay(self): def decay(self) -> 'Type':
t = self.remove_reference() t = self.remove_reference()
if t.is_pointer(): if t.is_pointer():
return t return t
...@@ -93,7 +94,7 @@ class Type: ...@@ -93,7 +94,7 @@ class Type:
return self.add_const() return self.add_const()
return self return self
def str(self): def str(self) -> str:
return self.name return self.name
...@@ -113,20 +114,20 @@ extern "C" ${error_type} ${name}(${params}) ...@@ -113,20 +114,20 @@ extern "C" ${error_type} ${name}(${params})
class CFunction: class CFunction:
def __init__(self, name): def __init__(self, name: str) -> None:
self.name = name self.name = name
self.params = [] self.params: List[str] = []
self.body = [] self.body: List[str] = []
self.va_start = [] self.va_start: List[str] = []
self.va_end = [] self.va_end: List[str] = []
def add_param(self, type, pname): def add_param(self, type: str, pname: str) -> None:
self.params.append('{} {}'.format(type, pname)) self.params.append('{} {}'.format(type, pname))
def add_statement(self, stmt): def add_statement(self, stmt: str) -> None:
self.body.append(stmt) self.body.append(stmt)
def add_vlist(self, name): def add_vlist(self, name: str) -> None:
last_param = self.params[-1].split()[-1] last_param = self.params[-1].split()[-1]
self.va_start = [ self.va_start = [
'va_list {};'.format(name), 'va_list {};'.format(name),
...@@ -135,7 +136,7 @@ class CFunction: ...@@ -135,7 +136,7 @@ class CFunction:
self.va_end = ['va_end({});'.format(name)] self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '') self.add_param('...', '')
def substitute(self, form): def substitute(self, form: Template) -> str:
return form.substitute(error_type=error_type, return form.substitute(error_type=error_type,
try_wrap=try_wrap, try_wrap=try_wrap,
name=self.name, name=self.name,
...@@ -144,25 +145,29 @@ class CFunction: ...@@ -144,25 +145,29 @@ class CFunction:
va_start="\n ".join(self.va_start), va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end)) va_end="\n ".join(self.va_end))
def generate_header(self): def generate_header(self) -> str:
return self.substitute(header_function) return self.substitute(header_function)
def generate_body(self): def generate_body(self) -> str:
return self.substitute(c_api_impl) return self.substitute(c_api_impl)
class BadParam: class BadParam:
def __init__(self, cond, msg): def __init__(self, cond: str, msg: str) -> None:
self.cond = cond self.cond = cond
self.msg = msg self.msg = msg
class Parameter: class Parameter:
def __init__(self, name, type, optional=False, returns=False): def __init__(self,
name: str,
type: str,
optional: bool = False,
returns: bool = False) -> None:
self.name = name self.name = name
self.type = Type(type) self.type = Type(type)
self.optional = optional self.optional = optional
self.cparams = [] self.cparams: List[Tuple[str, str]] = []
self.size_cparam = -1 self.size_cparam = -1
self.size_name = '' self.size_name = ''
self.read = '${name}' self.read = '${name}'
...@@ -170,15 +175,15 @@ class Parameter: ...@@ -170,15 +175,15 @@ class Parameter:
self.cpp_read = '${name}' self.cpp_read = '${name}'
self.cpp_write = '${name}' self.cpp_write = '${name}'
self.returns = returns self.returns = returns
self.bad_param_check = None self.bad_param_check: Optional[BadParam] = None
def get_name(self, prefix=None): def get_name(self, prefix: Optional[str] = None) -> str:
if prefix: if prefix:
return prefix + self.name return prefix + self.name
else: else:
return self.name return self.name
def get_cpp_type(self): def get_cpp_type(self) -> str:
if self.type.str() in cpp_type_map: if self.type.str() in cpp_type_map:
return cpp_type_map[self.type.basic().str()] return cpp_type_map[self.type.basic().str()]
elif self.type.basic().str() in cpp_type_map: elif self.type.basic().str() in cpp_type_map:
...@@ -188,7 +193,10 @@ class Parameter: ...@@ -188,7 +193,10 @@ class Parameter:
else: else:
return self.type.str() return self.type.str()
def substitute(self, s, prefix=None, result=None): def substitute(self,
s: str,
prefix: Optional[str] = None,
result: Optional[str] = None) -> str:
ctype = None ctype = None
if len(self.cparams) > 0: if len(self.cparams) > 0:
ctype = Type(self.cparams[0][0]).basic().str() ctype = Type(self.cparams[0][0]).basic().str()
...@@ -199,12 +207,13 @@ class Parameter: ...@@ -199,12 +207,13 @@ class Parameter:
size=self.size_name, size=self.size_name,
result=result or '') result=result or '')
def add_param(self, t, name=None): def add_param(self, t: Union[str, Type],
name: Optional[str] = None) -> None:
if not isinstance(t, str): if not isinstance(t, str):
t = t.str() t = t.str()
self.cparams.append((t, name or self.name)) self.cparams.append((t, name or self.name))
def add_size_param(self, name=None): def add_size_param(self, name: Optional[str] = None) -> None:
self.size_cparam = len(self.cparams) self.size_cparam = len(self.cparams)
self.size_name = name or self.name + '_size' self.size_name = name or self.name + '_size'
if self.returns: if self.returns:
...@@ -212,7 +221,7 @@ class Parameter: ...@@ -212,7 +221,7 @@ class Parameter:
else: else:
self.add_param('size_t', self.size_name) self.add_param('size_t', self.size_name)
def bad_param(self, cond, msg): def bad_param(self, cond: str, msg: str) -> None:
self.bad_param_check = BadParam(cond, msg) self.bad_param_check = BadParam(cond, msg)
def remove_size_param(self, name): def remove_size_param(self, name):
...@@ -223,7 +232,7 @@ class Parameter: ...@@ -223,7 +232,7 @@ class Parameter:
self.size_name = name self.size_name = name
return p return p
def update(self): def update(self) -> None:
t = self.type.basic().str() t = self.type.basic().str()
g = self.type.remove_generic().basic().str() g = self.type.remove_generic().basic().str()
if t in type_map: if t in type_map:
...@@ -239,18 +248,18 @@ class Parameter: ...@@ -239,18 +248,18 @@ class Parameter:
raise ValueError("Error for {}: write cannot be a string".format( raise ValueError("Error for {}: write cannot be a string".format(
self.type.str())) self.type.str()))
def cpp_param(self, prefix=None): def cpp_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${cpptype} ${name}', prefix=prefix) return self.substitute('${cpptype} ${name}', prefix=prefix)
def cpp_arg(self, prefix=None): def cpp_arg(self, prefix: Optional[str] = None) -> str:
return self.substitute(self.cpp_read, prefix=prefix) return self.substitute(self.cpp_read, prefix=prefix)
def cpp_output_args(self, prefix=None): def cpp_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [ return [
'&{prefix}{n}'.format(prefix=prefix, n=n) for t, n in self.cparams '&{prefix}{n}'.format(prefix=prefix, n=n) for t, n in self.cparams
] ]
def output_declarations(self, prefix=None): def output_declarations(self, prefix: Optional[str] = None) -> List[str]:
return [ return [
'{type} {prefix}{n};'.format(type=Type(t).remove_pointer().str(), '{type} {prefix}{n};'.format(type=Type(t).remove_pointer().str(),
prefix=prefix, prefix=prefix,
...@@ -262,16 +271,16 @@ class Parameter: ...@@ -262,16 +271,16 @@ class Parameter:
'&{prefix}{n};'.format(prefix=prefix, n=n) for t, n in self.cparams '&{prefix}{n};'.format(prefix=prefix, n=n) for t, n in self.cparams
] ]
def cpp_output(self, prefix=None): def cpp_output(self, prefix: Optional[str] = None) -> str:
return self.substitute(self.cpp_write, prefix=prefix) return self.substitute(self.cpp_write, prefix=prefix)
def input(self, prefix=None): def input(self, prefix: Optional[str] = None) -> str:
return '(' + self.substitute(self.read, prefix=prefix) + ')' return '(' + self.substitute(self.read, prefix=prefix) + ')'
def outputs(self, result=None): def outputs(self, result: Optional[str] = None) -> List[str]:
return [self.substitute(w, result=result) for w in self.write] return [self.substitute(w, result=result) for w in self.write]
def add_to_cfunction(self, cfunction): def add_to_cfunction(self, cfunction: CFunction) -> None:
for t, name in self.cparams: for t, name in self.cparams:
if t.startswith('...'): if t.startswith('...'):
cfunction.add_vlist(name) cfunction.add_vlist(name)
...@@ -285,35 +294,35 @@ class Parameter: ...@@ -285,35 +294,35 @@ class Parameter:
body=bad_param_error(msg))) body=bad_param_error(msg)))
def template_var(s): def template_var(s: str) -> str:
return '${' + s + '}' return '${' + s + '}'
def to_template_vars(params): def to_template_vars(params: List[Union[Any, Parameter]]) -> str:
return ', '.join([template_var(p.name) for p in params]) return ', '.join([template_var(p.name) for p in params])
class Function: class Function:
def __init__(self, def __init__(self,
name, name: str,
params=None, params: Optional[List[Parameter]] = None,
shared_size=False, shared_size: bool = False,
returns=None, returns: Optional[str] = None,
invoke=None, invoke: Optional[str] = None,
fname=None, fname: Optional[str] = None,
return_name=None, return_name: Optional[str] = None,
**kwargs): **kwargs) -> None:
self.name = name self.name = name
self.params = params or [] self.params = params or []
self.shared_size = False self.shared_size = False
self.cfunction = None self.cfunction: Optional[CFunction] = None
self.fname = fname self.fname = fname
self.invoke = invoke or '${__fname__}($@)' self.invoke = invoke or '${__fname__}($@)'
self.return_name = return_name or 'out' self.return_name = return_name or 'out'
self.returns = Parameter(self.return_name, returns, self.returns = Parameter(self.return_name, returns,
returns=True) if returns else None returns=True) if returns else None
def share_params(self): def share_params(self) -> None:
if self.shared_size == True: if self.shared_size == True:
size_param_name = 'size' size_param_name = 'size'
size_type = Type('size_t') size_type = Type('size_t')
...@@ -323,7 +332,7 @@ class Function: ...@@ -323,7 +332,7 @@ class Function:
size_type = Type(p[0]) size_type = Type(p[0])
self.params.append(Parameter(size_param_name, size_type.str())) self.params.append(Parameter(size_param_name, size_type.str()))
def update(self): def update(self) -> None:
self.share_params() self.share_params()
for param in self.params: for param in self.params:
param.update() param.update()
...@@ -331,11 +340,12 @@ class Function: ...@@ -331,11 +340,12 @@ class Function:
self.returns.update() self.returns.update()
self.create_cfunction() self.create_cfunction()
def inputs(self): def inputs(self) -> str:
return ', '.join([p.input() for p in self.params]) return ', '.join([p.input() for p in self.params])
def input_map(self): # TODO: Shoule we remove Optional?
m = {} def input_map(self) -> Dict[str, Optional[str]]:
m: Dict[str, Optional[str]] = {}
for p in self.params: for p in self.params:
m[p.name] = p.input() m[p.name] = p.input()
m['return'] = self.return_name m['return'] = self.return_name
...@@ -343,14 +353,22 @@ class Function: ...@@ -343,14 +353,22 @@ class Function:
m['__fname__'] = self.fname m['__fname__'] = self.fname
return m return m
def get_invoke(self): def get_invoke(self) -> str:
return Template(self.invoke).safe_substitute(self.input_map()) return Template(self.invoke).safe_substitute(self.input_map())
def write_to_tmp_var(self): def write_to_tmp_var(self) -> bool:
if not self.returns:
return False
return len(self.returns.write) > 1 or self.returns.write[0].count( return len(self.returns.write) > 1 or self.returns.write[0].count(
'${result}') > 1 '${result}') > 1
def create_cfunction(self): def get_cfunction(self) -> CFunction:
if self.cfunction:
return self.cfunction
raise Exception(
"self.cfunction is None: self.update() needs to be called.")
def create_cfunction(self) -> None:
self.cfunction = CFunction(self.name) self.cfunction = CFunction(self.name)
# Add the return as a parameter # Add the return as a parameter
if self.returns: if self.returns:
...@@ -358,12 +376,12 @@ class Function: ...@@ -358,12 +376,12 @@ class Function:
# Add the input parameters # Add the input parameters
for param in self.params: for param in self.params:
param.add_to_cfunction(self.cfunction) param.add_to_cfunction(self.cfunction)
f = self.get_invoke() f: Optional[str] = self.get_invoke()
# Write the assignments # Write the assignments
assigns = [] assigns = []
if self.returns: if self.returns:
result = f result = f
if self.write_to_tmp_var(): if self.write_to_tmp_var() and f:
f = 'auto&& api_result = ' + f f = 'auto&& api_result = ' + f
result = 'api_result' result = 'api_result'
else: else:
...@@ -416,31 +434,37 @@ cpp_class_constructor_template = Template(''' ...@@ -416,31 +434,37 @@ cpp_class_constructor_template = Template('''
class CPPMember: class CPPMember:
def __init__(self, name, function, prefix, method=True): def __init__(self,
name: str,
function: Function,
prefix: str,
method: bool = True) -> None:
self.name = name self.name = name
self.function = function self.function = function
self.prefix = prefix self.prefix = prefix
self.method = method self.method = method
def get_function_params(self): def get_function_params(self) -> List[Union[Any, Parameter]]:
if self.method: if self.method:
return self.function.params[1:] return self.function.params[1:]
else: else:
return self.function.params return self.function.params
def get_args(self): def get_args(self) -> str:
output_args = [] output_args = []
if self.function.returns: if self.function.returns:
output_args = self.function.returns.cpp_output_args(self.prefix) output_args = self.function.returns.cpp_output_args(self.prefix)
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
return ', '.join( return ', '.join(
['&{}'.format(self.function.cfunction.name)] + output_args + ['&{}'.format(self.function.cfunction.name)] + output_args +
[p.cpp_arg(self.prefix) for p in self.get_function_params()]) [p.cpp_arg(self.prefix) for p in self.get_function_params()])
def get_params(self): def get_params(self) -> str:
return ', '.join( return ', '.join(
[p.cpp_param(self.prefix) for p in self.get_function_params()]) [p.cpp_param(self.prefix) for p in self.get_function_params()])
def get_return_declarations(self): def get_return_declarations(self) -> str:
if self.function.returns: if self.function.returns:
return '\n '.join([ return '\n '.join([
d d
...@@ -452,7 +476,9 @@ class CPPMember: ...@@ -452,7 +476,9 @@ class CPPMember:
def get_result(self): def get_result(self):
return self.function.returns.input(self.prefix) return self.function.returns.input(self.prefix)
def generate_method(self): def generate_method(self) -> str:
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
if self.function.returns: if self.function.returns:
return_type = self.function.returns.get_cpp_type() return_type = self.function.returns.get_cpp_type()
return cpp_class_method_template.safe_substitute( return cpp_class_method_template.safe_substitute(
...@@ -472,7 +498,9 @@ class CPPMember: ...@@ -472,7 +498,9 @@ class CPPMember:
args=self.get_args(), args=self.get_args(),
success=success_type) success=success_type)
def generate_constructor(self, name): def generate_constructor(self, name: str) -> str:
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
return cpp_class_constructor_template.safe_substitute( return cpp_class_constructor_template.safe_substitute(
name=name, name=name,
cfunction=self.function.cfunction.name, cfunction=self.function.cfunction.name,
...@@ -482,98 +510,101 @@ class CPPMember: ...@@ -482,98 +510,101 @@ class CPPMember:
class CPPClass: class CPPClass:
def __init__(self, name, ctype): def __init__(self, name: str, ctype: str) -> None:
self.name = name self.name = name
self.ctype = ctype self.ctype = ctype
self.constructors = [] self.constructors: List[CPPMember] = []
self.methods = [] self.methods: List[CPPMember] = []
self.prefix = 'p' self.prefix = 'p'
def add_method(self, name, f): def add_method(self, name: str, f: Function) -> None:
self.methods.append(CPPMember(name, f, self.prefix, method=True)) self.methods.append(CPPMember(name, f, self.prefix, method=True))
def add_constructor(self, name, f): def add_constructor(self, name: str, f: Function) -> None:
self.constructors.append(CPPMember(name, f, self.prefix, method=True)) self.constructors.append(CPPMember(name, f, self.prefix, method=True))
def generate_methods(self): def generate_methods(self) -> str:
return '\n '.join([m.generate_method() for m in self.methods]) return '\n '.join([m.generate_method() for m in self.methods])
def generate_constructors(self): def generate_constructors(self) -> str:
return '\n '.join( return '\n '.join(
[m.generate_constructor(self.name) for m in self.constructors]) [m.generate_constructor(self.name) for m in self.constructors])
def substitute(self, s, **kwargs): def substitute(self, s: Union[string.Template, str], **kwargs) -> str:
t = s t = string.Template(s) if isinstance(s, str) else s
if isinstance(s, str):
t = string.Template(s)
destroy = self.ctype + '_destroy' destroy = self.ctype + '_destroy'
return t.safe_substitute(name=self.name, return t.safe_substitute(name=self.name,
ctype=self.ctype, ctype=self.ctype,
destroy=destroy, destroy=destroy,
**kwargs) **kwargs)
def generate(self): def generate(self) -> str:
return self.substitute( return self.substitute(
cpp_class_template, cpp_class_template,
constructors=self.substitute(self.generate_constructors()), constructors=self.substitute(self.generate_constructors()),
methods=self.substitute(self.generate_methods())) methods=self.substitute(self.generate_methods()))
def params(virtual=None, **kwargs): def params(virtual: Optional[Dict[str, str]] = None,
**kwargs) -> List[Parameter]:
result = [] result = []
for name in virtual or {}: v: Dict[str, str] = virtual or {}
result.append(Parameter(name, virtual[name])) for name in v:
result.append(Parameter(name, v[name]))
for name in kwargs: for name in kwargs:
result.append(Parameter(name, kwargs[name])) result.append(Parameter(name, kwargs[name]))
return result return result
def add_function(name, *args, **kwargs): def add_function(name: str, *args, **kwargs) -> Function:
f = Function(name, *args, **kwargs) f = Function(name, *args, **kwargs)
functions.append(f) functions.append(f)
return f return f
def once(f): def once(f: Callable) -> Any:
@wraps(f) @wraps(f)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not decorated.has_run: if not decorated.has_run:
decorated.has_run = True decorated.has_run = True
return f(*args, **kwargs) return f(*args, **kwargs)
decorated.has_run = False d: Any = decorated
return decorated d.has_run = False
return d
@once @once
def process_functions(): def process_functions() -> None:
for f in functions: for f in functions:
f.update() f.update()
def generate_lines(p): def generate_lines(p: List[str]) -> str:
return '\n'.join(p) return '\n'.join(p)
def generate_c_header(): def generate_c_header() -> str:
process_functions() process_functions()
return generate_lines(c_header_preamble + return generate_lines(
[f.cfunction.generate_header() for f in functions]) c_header_preamble +
[f.get_cfunction().generate_header() for f in functions])
def generate_c_api_body(): def generate_c_api_body() -> str:
process_functions() process_functions()
return generate_lines(c_api_body_preamble + return generate_lines(
[f.cfunction.generate_body() for f in functions]) c_api_body_preamble +
[f.get_cfunction().generate_body() for f in functions])
def generate_cpp_header(): def generate_cpp_header() -> str:
process_functions() process_functions()
return generate_lines(cpp_header_preamble + return generate_lines(cpp_header_preamble +
[c.generate() for c in cpp_classes]) [c.generate() for c in cpp_classes])
def cwrap(name): def cwrap(name: str) -> Callable:
def with_cwrap(f): def with_cwrap(f):
type_map[name] = f type_map[name] = f
...@@ -677,13 +708,17 @@ protected: ...@@ -677,13 +708,17 @@ protected:
@once @once
def add_handle_preamble(): def add_handle_preamble() -> None:
c_api_body_preamble.append(handle_preamble) c_api_body_preamble.append(handle_preamble)
cpp_header_preamble.append( cpp_header_preamble.append(
string.Template(cpp_handle_preamble).substitute(success=success_type)) string.Template(cpp_handle_preamble).substitute(success=success_type))
def add_handle(name, ctype, cpptype, destroy=None, ref=None): def add_handle(name: str,
ctype: str,
cpptype: str,
destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None:
opaque_type = ctype + '_t' opaque_type = ctype + '_t'
def handle_wrap(p): def handle_wrap(p):
...@@ -718,8 +753,12 @@ def add_handle(name, ctype, cpptype, destroy=None, ref=None): ...@@ -718,8 +753,12 @@ def add_handle(name, ctype, cpptype, destroy=None, ref=None):
@cwrap('std::vector') @cwrap('std::vector')
def vector_c_wrap(p): def vector_c_wrap(p: Parameter) -> None:
t = p.type.inner_type().add_pointer() inner = p.type.inner_type()
# Not a generic type
if not inner:
return
t = inner.add_pointer()
if p.returns: if p.returns:
if p.type.is_reference(): if p.type.is_reference():
if p.type.is_const(): if p.type.is_const():
...@@ -747,7 +786,7 @@ def vector_c_wrap(p): ...@@ -747,7 +786,7 @@ def vector_c_wrap(p):
@cwrap('std::string') @cwrap('std::string')
def string_c_wrap(p): def string_c_wrap(p: Parameter) -> None:
t = Type('char*') t = Type('char*')
if p.returns: if p.returns:
if p.type.is_reference(): if p.type.is_reference():
...@@ -771,7 +810,11 @@ def string_c_wrap(p): ...@@ -771,7 +810,11 @@ def string_c_wrap(p):
class Handle: class Handle:
def __init__(self, name, ctype, cpptype, ref=None): def __init__(self,
name: str,
ctype: str,
cpptype: str,
ref: Optional[bool] = None) -> None:
self.name = name self.name = name
self.ctype = ctype self.ctype = ctype
self.cpptype = cpptype self.cpptype = cpptype
...@@ -779,17 +822,21 @@ class Handle: ...@@ -779,17 +822,21 @@ class Handle:
add_handle(name, ctype, cpptype, ref=ref) add_handle(name, ctype, cpptype, ref=ref)
cpp_type_map[cpptype] = name cpp_type_map[cpptype] = name
def cname(self, name): def cname(self, name: str) -> str:
return self.ctype + '_' + name return self.ctype + '_' + name
def substitute(self, s, **kwargs): def substitute(self, s: str, **kwargs) -> str:
return Template(s).safe_substitute(name=self.name, return Template(s).safe_substitute(name=self.name,
ctype=self.ctype, ctype=self.ctype,
cpptype=self.cpptype, cpptype=self.cpptype,
**kwargs) **kwargs)
def constructor(self, name, params=None, fname=None, invoke=None, def constructor(self,
**kwargs): name: str,
params: Optional[List[Parameter]] = None,
fname: Optional[str] = None,
invoke: Optional[str] = None,
**kwargs) -> 'Handle':
create = self.substitute('allocate<${cpptype}>($@)') create = self.substitute('allocate<${cpptype}>($@)')
if fname: if fname:
create = self.substitute('allocate<${cpptype}>(${fname}($@))', create = self.substitute('allocate<${cpptype}>(${fname}($@))',
...@@ -805,13 +852,13 @@ class Handle: ...@@ -805,13 +852,13 @@ class Handle:
return self return self
def method(self, def method(self,
name, name: str,
params=None, params: Optional[List[Parameter]] = None,
fname=None, fname: Optional[str] = None,
invoke=None, invoke: Optional[str] = None,
cpp_name=None, cpp_name: Optional[str] = None,
const=None, const: Optional[bool] = None,
**kwargs): **kwargs) -> 'Handle':
cpptype = self.cpptype cpptype = self.cpptype
if const: if const:
cpptype = Type(cpptype).add_const().str() cpptype = Type(cpptype).add_const().str()
...@@ -832,11 +879,14 @@ class Handle: ...@@ -832,11 +879,14 @@ class Handle:
add_function(self.cname(name), params=params, **kwargs) add_function(self.cname(name), params=params, **kwargs)
return self return self
def add_cpp_class(self): def add_cpp_class(self) -> None:
cpp_classes.append(self.cpp_class) cpp_classes.append(self.cpp_class)
def handle(ctype, cpptype, name=None, ref=None): def handle(ctype: str,
cpptype: str,
name: Optional[str] = None,
ref: Optional[bool] = None) -> Callable:
def with_handle(f): def with_handle(f):
n = name or f.__name__ n = name or f.__name__
h = Handle(n, ctype, cpptype, ref=ref) h = Handle(n, ctype, cpptype, ref=ref)
...@@ -865,10 +915,10 @@ def template_eval(template, **kwargs): ...@@ -865,10 +915,10 @@ def template_eval(template, **kwargs):
return template return template
def run(): def run(args: List[str]) -> None:
runpy.run_path(sys.argv[1]) runpy.run_path(args[0])
if len(sys.argv) > 2: if len(args) > 1:
f = open(sys.argv[2]).read() f = open(args[1]).read()
r = template_eval(f) r = template_eval(f)
sys.stdout.write(r) sys.stdout.write(r)
else: else:
...@@ -879,4 +929,4 @@ def run(): ...@@ -879,4 +929,4 @@ def run():
if __name__ == "__main__": if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__'] sys.modules['api'] = sys.modules['__main__']
run() run(sys.argv[1:])
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
SRC_DIR=$DIR/../src SRC_DIR=$DIR/../src
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $SRC_DIR/include/migraphx/{}" PYTHON=python3
if type -p python3.6 > /dev/null ; then
PYTHON=python3.6
fi
if type -p python3.8 > /dev/null ; then
PYTHON=python3.8
fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $SRC_DIR/include/migraphx/{}"
function api { function api {
python3.6 $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-5.0 -style=file > $2 $PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-5.0 -style=file > $2
} }
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
......
...@@ -21,7 +21,7 @@ def parse_args(): ...@@ -21,7 +21,7 @@ def parse_args():
description="Parser for MIGraphX ROCTX Markers") description="Parser for MIGraphX ROCTX Markers")
parser.add_argument('--json-path', parser.add_argument('--json-path',
type=str, type=str,
metavar='json_path', metavar='json-path',
help='Path to json file') help='Path to json file')
parser.add_argument('--out', parser.add_argument('--out',
type=str, type=str,
......
import os import os, sys
import numpy as np import numpy as np
import argparse import argparse
import onnx import onnx
...@@ -54,36 +54,112 @@ def read_pb_file(filename): ...@@ -54,36 +54,112 @@ def read_pb_file(filename):
tensor.ParseFromString(data_str) tensor.ParseFromString(data_str)
np_array = numpy_helper.to_array(tensor) np_array = numpy_helper.to_array(tensor)
return np_array return tensor.name, np_array
def wrapup_inputs(io_folder, parameter_names): def wrapup_inputs(io_folder, param_names):
index = 0
param_map = {} param_map = {}
for param_name in parameter_names: data_array = []
file_name = io_folder + '/input_' + str(index) + '.pb' name_array = []
data = read_pb_file(file_name) for i in range(len(param_names)):
param_map[param_name] = data file_name = io_folder + '/input_' + str(i) + '.pb'
index = index + 1 name, data = read_pb_file(file_name)
param_map[name] = data
data_array.append(data)
if name:
name_array.append(name)
if len(name_array) < len(data_array):
param_map = {}
for i in range(len(param_names)):
param_map[param_names[i]] = data_array[i]
return param_map
for name in param_names:
if not name in param_map.keys():
print("Input {} does not exist!".format(name))
sys.exit()
return param_map return param_map
def read_outputs(io_folder, out_num): def read_outputs(io_folder, out_names):
outputs = [] outputs = []
for i in range(out_num): data_array = []
name_array = []
for i in range(len(out_names)):
file_name = io_folder + '/output_' + str(i) + '.pb' file_name = io_folder + '/output_' + str(i) + '.pb'
data = read_pb_file(file_name) name, data = read_pb_file(file_name)
outputs.append(data) data_array.append(data)
if name:
name_array.append(name)
if len(name_array) < len(data_array):
return data_array
for name in out_names:
index = name_array.index(name)
outputs.append(data_array[index])
return outputs return outputs
def model_parameter_names(model_file_name):
with open(model_file_name, 'rb') as pfile:
data_str = pfile.read()
model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str)
init_names = set([(i.name) for i in model_proto.graph.initializer])
param_names = [
input.name for input in model_proto.graph.input
if input.name not in init_names
]
return param_names
def model_output_names(model_file_name):
with open(model_file_name, 'rb') as pfile:
data_str = pfile.read()
model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str)
output_names = [out.name for out in model_proto.graph.output]
return output_names
def get_input_shapes(sample_case, param_names):
param_shape_map = {}
name_array = []
shape_array = []
for i in range(len(param_names)):
file_name = sample_case + '/input_' + str(i) + '.pb'
name, data = read_pb_file(file_name)
param_shape_map[name] = data.shape
shape_array.append(data.shape)
if name:
name_array.append(name)
if len(name_array) < len(shape_array):
param_shape_map = {}
for i in range(len(param_names)):
param_shape_map[param_names[i]] = shape_array[i]
return param_shape_map
for name in param_names:
if not name in param_shape_map:
print("Input {} does not exist!".format(name))
sys.exit()
return param_shape_map
def run_one_case(model, param_map): def run_one_case(model, param_map):
# convert np array to model argument # convert np array to model argument
pp = {} pp = {}
for key, val in param_map.items(): for key, val in param_map.items():
print("input = {}".format(val))
pp[key] = migraphx.argument(val) pp[key] = migraphx.argument(val)
# run the model # run the model
...@@ -106,12 +182,11 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3): ...@@ -106,12 +182,11 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3):
out_num = len(gold_outputs) out_num = len(gold_outputs)
ret = True ret = True
for i in range(out_num): for i in range(out_num):
print("Expected value: \n{}".format(gold_outputs[i]))
print("Actual value: \n{}".format(outputs[i]))
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
print("Output {} is incorrect ...".format(i)) print("\nOutput {} is incorrect ...".format(i))
print("Expected value: \n{}".format(gold_outputs[i])) print("Expected value: \n{}".format(gold_outputs[i]))
print("Actual value: \n{}".format(outputs[i])) print("......")
print("Actual value: \n{}\n".format(outputs[i]))
ret = False ret = False
return ret return ret
...@@ -142,21 +217,32 @@ def main(): ...@@ -142,21 +217,32 @@ def main():
# get model full path # get model full path
model_name = get_model_name(test_loc) model_name = get_model_name(test_loc)
model_path_name = test_loc + '/' + model_name model_path_name = test_loc + '/' + model_name
# read and compile model
model = migraphx.parse_onnx(model_path_name)
param_names = model.get_parameter_names()
output_shapes = model.get_output_shapes()
model.compile(migraphx.get_target(target)) # get param names
param_names = model_parameter_names(model_path_name)
# get output names
output_names = model_output_names(model_path_name)
# get test cases # get test cases
cases = get_test_cases(test_loc) cases = get_test_cases(test_loc)
sample_case = test_loc + '/' + cases[0]
param_shapes = get_input_shapes(sample_case, param_names)
for name, dims in param_shapes.items():
print("Input: {}, shape: {}".format(name, dims))
print()
# read and compile model
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
model.compile(migraphx.get_target(target))
# get test cases
case_num = len(cases) case_num = len(cases)
correct_num = 0 correct_num = 0
for case_name in cases: for case_name in cases:
io_folder = test_loc + '/' + case_name io_folder = test_loc + '/' + case_name
input_data = wrapup_inputs(io_folder, param_names) input_data = wrapup_inputs(io_folder, param_names)
gold_output_data = read_outputs(io_folder, len(output_shapes)) gold_outputs = read_outputs(io_folder, output_names)
# if input shape is different from model shape, reload and recompile # if input shape is different from model shape, reload and recompile
# model # model
...@@ -170,7 +256,7 @@ def main(): ...@@ -170,7 +256,7 @@ def main():
output_data = run_one_case(model, input_data) output_data = run_one_case(model, input_data)
# check output correctness # check output correctness
ret = check_correctness(gold_output_data, output_data) ret = check_correctness(gold_outputs, output_data)
if ret: if ret:
correct_num += 1 correct_num += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment