Commit b076d0f4 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-layernorm-merge' into bert-opt-layernorm

parents 03c6967e d705e483
No preview for this file type
...@@ -3378,13 +3378,127 @@ TEST_CASE(nms_test) ...@@ -3378,13 +3378,127 @@ TEST_CASE(nms_test)
auto st = mm->add_parameter("score_threshold", sst); auto st = mm->add_parameter("score_threshold", sst);
auto ret = mm->add_instruction( auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), b, s, mo, iou, st); migraphx::make_op("nonmaxsuppression", {{"center_point_box", true}}), b, s, mo, iou, st);
mm->add_return({ret}); mm->add_return({ret});
auto prog = migraphx::parse_onnx("nms_test.onnx"); auto prog = migraphx::parse_onnx("nms_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(nms_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {{1, 10, 0}, {6, 6, 0}, {4, 4, 0}}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 10, 0}, {1, 1, 0}, {6, 6, 0}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
migraphx::shape siou{migraphx::shape::float_type, {1}};
auto iou = mm->add_parameter("iou_threshold", siou);
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
b,
s,
mo,
iou,
st);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_dynamic_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(nms_dynamic_boxes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {{1, 1, 0}, {6, 20, 0}, {4, 4, 0}}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {6, 20, 0}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
migraphx::shape siou{migraphx::shape::float_type, {1}};
auto iou = mm->add_parameter("iou_threshold", siou);
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"use_dyn_output", true}}), b, s, mo, iou, st);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {6, 20, 0};
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_dynamic_boxes_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(nms_dynamic_classes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {1, 6, 4}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 1, 0}, {1, 10, 0}, {6, 6, 0}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
migraphx::shape siou{migraphx::shape::float_type, {1}};
auto iou = mm->add_parameter("iou_threshold", siou);
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"use_dyn_output", true}}), b, s, mo, iou, st);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_dynamic_classes_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(nms_overwrite_use_dyn_output_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {1, 6, 4}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {1, 1, 6}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
migraphx::shape siou{migraphx::shape::float_type, {1}};
auto iou = mm->add_parameter("iou_threshold", siou);
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"use_dyn_output", true}}), b, s, mo, iou, st);
mm->add_return({ret});
migraphx::onnx_options options;
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_use_dyn_output_false_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(nonzero_dynamic_test) TEST_CASE(nonzero_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -1135,6 +1135,149 @@ TEST_CASE(multinomial) ...@@ -1135,6 +1135,149 @@ TEST_CASE(multinomial)
throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s); throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s);
} }
TEST_CASE(nms_shape)
{
// use_dyn_output == false
migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
migraphx::shape max_out_s{migraphx::shape::int64_type, {1}};
migraphx::shape iou_thres_s{migraphx::shape::float_type, {1}};
migraphx::shape score_thres_s{migraphx::shape::float_type, {1}};
migraphx::shape output_s{migraphx::shape::int64_type, {6, 3}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", false}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// use_dyn_output == true
output_s = {migraphx::shape::int64_type, {{0, 6, 0}, {3, 3, 0}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic batches
boxes_s = {migraphx::shape::float_type, {{1, 3, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 3, 0}, {1, 1, 0}, {6, 6, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 18, 0}, {3, 3, 0}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic num boxes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 20, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {6, 20, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 20, 0}, {3, 3, 0}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// use_dyn_output false with dynamic input shape
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", false}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic classes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {6, 6, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 6, 0}, {3, 3, 0}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// fixed mismatch batches
boxes_s = {migraphx::shape::float_type, {2, 6, 4}};
scores_s = {migraphx::shape::float_type, {1, 1, 6}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// fixed mismatch num boxes
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {1, 1, 4}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic mismatch batches
boxes_s = {migraphx::shape::float_type, {{1, 4, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{2, 8, 0}, {1, 1, 0}, {6, 6, 0}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic mismatch num boxes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 8, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {3, 9, 0}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic number of classes, fixed boxes_s, mismatch batches
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {{1, 3, 0}, {1, 3, 0}, {6, 6, 0}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
// dynamic number of classes, fixed boxes_s, mismatch num boxes
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {4, 8, 0}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
scores_s,
max_out_s,
iou_thres_s,
score_thres_s);
}
TEST_CASE(pooling_shape) TEST_CASE(pooling_shape)
{ {
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
......
...@@ -3624,6 +3624,174 @@ TEST_CASE(neg_test) ...@@ -3624,6 +3624,174 @@ TEST_CASE(neg_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(nms_dynamic_out_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}};
std::vector<float> boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0,
0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_l,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 3, 0}, {6, 6, 0}, {4, 4, 0}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 3, 0}, {1, 1, 0}, {6, 6, 0}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s);
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_p,
scores_p,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0,
0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0,
0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0,
0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
std::vector<float> scores_vec = {
0.9, 0.75, 0.6, 0.95, 0.5, 0.3, 0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 6, 4}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {2, 1, 6}};
migraphx::parameter_map params0;
params0["boxes"] = migraphx::argument(input_fixed_shape0, boxes_vec.data());
params0["scores"] = migraphx::argument(input_fixed_shape1, scores_vec.data());
auto output = p.eval(params0).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 1, 0, 3, 1, 0, 0, 1, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_dynamic_boxes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1, 0}, {4, 20, 0}, {4, 4, 0}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {4, 20, 0}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s);
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_p,
scores_p,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0,
0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 6, 4}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 6}};
migraphx::parameter_map params0;
params0["boxes"] = migraphx::argument(input_fixed_shape0, boxes_vec.data());
params0["scores"] = migraphx::argument(input_fixed_shape1, scores_vec.data());
auto output = p.eval(params0).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_dynamic_classes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1, 0}, {6, 6, 0}, {4, 4, 0}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {6, 6, 0}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s);
auto max_out_l = mm->add_literal(int64_t{2});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_p,
scores_p,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> boxes_vec = {0.0, 0.0, 1.0, 1.0, 0.0, 0.1, 1.0, 1.1,
0.0, -0.1, 1.0, 0.9, 0.0, 10.0, 1.0, 11.0,
0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0};
std::vector<float> scores_vec = {
0.9, 0.75, 0.6, 0.95, 0.5, 0.3, 0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 6, 4}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 2, 6}};
migraphx::parameter_map params0;
params0["boxes"] = migraphx::argument(input_fixed_shape0, boxes_vec.data());
params0["scores"] = migraphx::argument(input_fixed_shape1, scores_vec.data());
auto output = p.eval(params0).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 1, 3, 0, 1, 0};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_not_center_test) TEST_CASE(nms_not_center_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -3642,7 +3810,9 @@ TEST_CASE(nms_not_center_test) ...@@ -3642,7 +3810,9 @@ TEST_CASE(nms_not_center_test)
auto iou_threshold = mm->add_literal(0.5f); auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f); auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression"), // set use_dyn_output back to false in operator map
auto r =
mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"use_dyn_output", false}}),
boxes_l, boxes_l,
scores_l, scores_l,
max_out_l, max_out_l,
...@@ -3675,7 +3845,8 @@ TEST_CASE(nms_test) ...@@ -3675,7 +3845,8 @@ TEST_CASE(nms_test)
auto iou_threshold = mm->add_literal(0.5f); auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f); auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), auto r =
mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", true}}),
boxes_l, boxes_l,
scores_l, scores_l,
max_out_l, max_out_l,
...@@ -3712,7 +3883,8 @@ TEST_CASE(nms_transpose1_test) ...@@ -3712,7 +3883,8 @@ TEST_CASE(nms_transpose1_test)
auto transpose_boxes = mm->add_instruction( auto transpose_boxes = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), t_boxes_l); migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), t_boxes_l);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), auto r =
mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", true}}),
transpose_boxes, transpose_boxes,
scores_l, scores_l,
max_out_l, max_out_l,
...@@ -3749,7 +3921,8 @@ TEST_CASE(nms_transpose2_test) ...@@ -3749,7 +3921,8 @@ TEST_CASE(nms_transpose2_test)
auto transpose_boxes = mm->add_instruction( auto transpose_boxes = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t_boxes_l); migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t_boxes_l);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), auto r =
mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", true}}),
transpose_boxes, transpose_boxes,
scores_l, scores_l,
max_out_l, max_out_l,
......
...@@ -37,8 +37,10 @@ ...@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp> #include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -64,12 +66,12 @@ struct target ...@@ -64,12 +66,12 @@ struct target
*/ */
context get_context() const; context get_context() const;
/** /**
* @brief Check how well an instruction is supported on a target with the given metric * @brief Get the ranges of instructions that are supported on a target
* @param ins Instruction to check if it's supported * @param module Module to check for supported instructions
* @param metric Used to define how the return value should be interpreted * @param metric Used to define how the quality of the support should be measured
* @return The value based on the chosen metric. Negative numbers mean unsupported * @return the supported segments of the graph
*/ */
float is_supported(T&, instruction_ref ins, support_metric m) const; supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg) ...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
} }
template <class T> template <class T>
float target_is_supported(T&, instruction_ref, support_metric) supported_segments target_find_supported(T&, const_module_ref, support_metric)
{ {
return 0; return {};
} }
<% <%
...@@ -125,7 +127,7 @@ interface('target', ...@@ -125,7 +127,7 @@ interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'), virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
......
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
##################################################################################### #####################################################################################
import string, sys, re import string, sys, re
trivial = ['std::size_t', 'instruction_ref', 'support_metric'] trivial = [
'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref'
]
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment