Unverified Commit 35e5298e authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Handle all slice input variations (#2394)

parent 200c7038
......@@ -40,6 +40,8 @@ namespace op {
* 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* Uses the dynamic_dimension.max value for dynamic shapes. Returns the original vector
* (no normalization) if any of dynamic_dimension[axes] are not fixed.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
......
This diff is collapsed.
......@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(input_shape.dynamic())
{
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
return not input_shape.dyn_dims().at(ax).is_fixed();
}))
{
return vec;
}
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i);
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
return input_shape.dyn_dims().at(i).max;
});
}
else
......
......@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std::vector<int64_t> insert(instruction_ref arg)
{
std::vector<int64_t> result;
......
......@@ -3233,6 +3233,64 @@ TEST_CASE(slice_static_shape)
TEST_CASE(slice_var_inputs_static_shape0)
{
// attr ends and axes set; inputs are (data, input_starts)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"ends", {2, 3}}, {"axes", {1, 2}}}),
input,
starts);
}
TEST_CASE(slice_var_inputs_static_mismatch_error0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
throws_shape(
migraphx::make_op("slice", {{"ends", {2, 3, 4}}, {"axes", {0, 1, 2}}}), input, starts);
}
TEST_CASE(slice_var_inputs_static_shape1)
{
// attr starts and axes set; inputs are (data, input_ends)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"starts", {0, 1}}, {"axes", {1, 2}}}),
input,
ends);
}
TEST_CASE(slice_var_inputs_static_mismatch_error1)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
throws_shape(
migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"axes", {0, 1, 2}}}), input, ends);
}
TEST_CASE(slice_var_inputs_static_shape2)
{
// attr starts and ends set; inputs are (data, input_axes)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"starts", {0, 1}}, {"ends", {1, 2}}}),
input,
axes);
}
TEST_CASE(slice_var_inputs_static_mismatch_error2)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
throws_shape(
migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"ends", {3, 4, 4}}}), input, axes);
}
TEST_CASE(slice_var_inputs_static_shape3)
{
// attr axes set; inputs are (data, input_starts, input_ends)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
......@@ -3243,7 +3301,57 @@ TEST_CASE(slice_var_inputs_static_shape0)
ends);
}
TEST_CASE(slice_var_inputs_static_shape1)
TEST_CASE(slice_var_inputs_static_mismatch_error3)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
throws_shape(migraphx::make_op("slice", {{"axes", {0, 1, 2}}}), input, starts, ends);
}
TEST_CASE(slice_var_inputs_static_shape4)
{
// attr ends set; inputs are (data, input_starts, input_axes)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"ends", {3, 4}}}),
input,
starts,
axes);
}
TEST_CASE(slice_var_inputs_static_mismatch_error4)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
throws_shape(migraphx::make_op("slice", {{"ends", {3, 3, 3}}}), input, starts, axes);
}
TEST_CASE(slice_var_inputs_static_shape5)
{
// attr starts set; inputs are (data, input_ends, input_axes)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"starts", {0, 2}}}),
input,
ends,
axes);
}
TEST_CASE(slice_var_inputs_static_mismatch_error5)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
throws_shape(migraphx::make_op("slice", {{"starts", {0, 1, 2}}}), input, ends, axes);
}
TEST_CASE(slice_var_inputs_static_shape6)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
......@@ -3257,7 +3365,7 @@ TEST_CASE(slice_var_inputs_static_shape1)
axes);
}
TEST_CASE(slice_var_inputs_static_error0)
TEST_CASE(slice_var_inputs_static_mismatch_error6)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
......@@ -3268,17 +3376,125 @@ TEST_CASE(slice_var_inputs_static_error0)
TEST_CASE(slice_var_inputs_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
// attr ends and axes set; inputs are (data, input_starts)
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 6}, {0, 6}}},
migraphx::make_op("slice", {{"ends", {2, 3}}, {"axes", {1, 2}}}),
input,
starts);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error0)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
throws_shape(
migraphx::make_op("slice", {{"ends", {2, 3, 4}}, {"axes", {0, 1, 2}}}), input, starts);
}
TEST_CASE(slice_var_inputs_dyn_shape1)
{
// attr starts and axes set; inputs are (data, input_ends)
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 4}, {0, 4}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 6}, {0, 6}}},
migraphx::make_op("slice", {{"starts", {0, 1}}, {"axes", {1, 2}}}),
input,
ends);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error1)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
throws_shape(
migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"axes", {0, 1, 2}}}), input, ends);
}
TEST_CASE(slice_var_inputs_dyn_shape2)
{
// attr starts and ends set; inputs are (data, input_axes)
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 6}, {0, 6}}},
migraphx::make_op("slice", {{"starts", {0, 1}}, {"ends", {8, 8}}}),
input,
axes);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error2)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
throws_shape(
migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"ends", {3, 4, 4}}}), input, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape3)
{
// attr axes set; inputs are (data, input_starts, input_ends)
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 6}, {0, 6}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_dyn_shape1)
TEST_CASE(slice_var_inputs_dyn_mismatch_error3)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
throws_shape(migraphx::make_op("slice", {{"axes", {0, 1, 2}}}), input, starts, ends);
}
TEST_CASE(slice_var_inputs_dyn_shape4)
{
// attr ends set; inputs are (data, input_starts, input_axes)
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 6}, {0, 6}}},
migraphx::make_op("slice", {{"ends", {3, 4}}}),
input,
starts,
axes);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error4)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
throws_shape(migraphx::make_op("slice", {{"ends", {3, 3, 3}}}), input, starts, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape5)
{
// attr starts set; inputs are (data, input_ends, input_axes)
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 6}, {0, 6}}},
migraphx::make_op("slice", {{"starts", {0, 2}}}),
input,
ends,
axes);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error5)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
throws_shape(migraphx::make_op("slice", {{"starts", {0, 1, 2}}}), input, ends, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape6)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
......@@ -3292,6 +3508,15 @@ TEST_CASE(slice_var_inputs_dyn_shape1)
axes);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error6)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {3}};
throws_shape(migraphx::make_op("slice"), input, starts, ends, axes);
}
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
......
......@@ -157,7 +157,169 @@ TEST_CASE(slice_var_inputs_static2)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn)
TEST_CASE(slice_var_inputs_dyn0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"ends", {10}}}), input, starts);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {-5}}}), input, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> ends_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["ends"] = migraphx::argument(s1, ends_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(2 * 2 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice", {{"starts", {1}}, {"ends", {-1}}}), input, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> axes_data = {2};
params["input"] = migraphx::argument(s2, input_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 4, 7, 10};
std::vector<int> results_vector(2 * 2 * 1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn3)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> starts_data = {1};
std::vector<int> ends_data = {std::numeric_limits<int>::max()};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, starts_data.data());
params["ends"] = migraphx::argument(s1, ends_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn4)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice", {{"ends", {std::numeric_limits<int>::max()}}}),
input,
starts,
axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> starts_data = {1};
std::vector<int> axes_data = {2};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, starts_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn5)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice", {{"starts", {-4}}}), input, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> ends_data = {2};
std::vector<int> axes_data = {2};
params["input"] = migraphx::argument(s2, input_data.data());
params["ends"] = migraphx::argument(s1, ends_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn6)
{
migraphx::program p;
auto* mm = p.get_main_module();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment